]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-backend v2 : add ggml_backend_sched (#586)
authorslaren <redacted>
Mon, 30 Oct 2023 20:28:09 +0000 (21:28 +0100)
committerGitHub <redacted>
Mon, 30 Oct 2023 20:28:09 +0000 (22:28 +0200)
* ggml-backend-v2 wip

* fix metal build

* ggml-alloc : use a real backend buffer in measure mode

* backend sched : ignore view ops to reduce the number of splits

* dynamic ggml_cgraph wip

* dyn graphs : remove n_tasks from ggml_cplan

* dyn graphs : update ggml_graph_import

* reset hash table in ggml_build_forward

* ggml-alloc : split into tensor and graph allocators

* add ggml_backend_sched_set_node_backend

* remove ggml_build_forward_ctx, ggml_build_backward_ctx
add ggml_opt_params::graph_size
add ggml_new_graph_custom, ggml_graph_overhead_custom
add ggml_graph_clear

* update examples and tests, fix issues

* update more examples

* update gpt-2/main-backend.cpp from master

* ggml : fix copmile warning

* ci : update yolo, fix mnist, use gpt-2-backend

* ggml : fix uninit warning

* ci : switch to gpt-2-backend2

ggml-ci

* metal : skip noops early to avoid warnings from ggml_metal_get_buffer

---------

Co-authored-by: Georgi Gerganov <redacted>
42 files changed:
ci/run.sh
examples/dolly-v2/main.cpp
examples/gpt-2/CMakeLists.txt
examples/gpt-2/main-alloc.cpp [new file with mode: 0644]
examples/gpt-2/main-backend.cpp [new file with mode: 0644]
examples/gpt-2/main-batched.cpp
examples/gpt-2/main-ctx.cpp [new file with mode: 0644]
examples/gpt-2/main.cpp
examples/gpt-j/main.cpp
examples/gpt-neox/main.cpp
examples/mnist/main-cnn.cpp
examples/mnist/main-cpu.cpp
examples/mnist/main-mtl.cpp
examples/mnist/main.cpp
examples/mpt/main.cpp
examples/replit/main.cpp
examples/sam/main.cpp
examples/starcoder/main.cpp
examples/starcoder/starcoder-mmap.cpp
examples/whisper/whisper.cpp
examples/yolo/yolov3-tiny.cpp
include/ggml/ggml-alloc.h
include/ggml/ggml-backend.h
include/ggml/ggml.h
src/CMakeLists.txt
src/ggml-alloc.c
src/ggml-backend-impl.h [new file with mode: 0644]
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-impl.h [new file with mode: 0644]
src/ggml-metal.m
src/ggml.c
tests/test-blas0.c
tests/test-conv-transpose.c
tests/test-customop.c
tests/test-grad0.cpp
tests/test-mul-mat0.c
tests/test-opt.cpp
tests/test-pool.c
tests/test-rel-pos.c
tests/test-xpos.c
tests/test1.c

index 6b12b4aecf3090b72da10365ba1576b8c7d2de77..7afe8304fe402170a694f02c5de229d62ca8b1bd 100644 (file)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -145,8 +145,8 @@ function gg_run_gpt_2 {
     model="../models-mnt/gpt-2/ggml-model-gpt-2-117M.bin"
     prompts="../examples/prompts/gpt-2.txt"
 
-    (time ./bin/gpt-2 --model ${model} -s 1234 -n 64 -tt ${prompts}                       ) 2>&1 | tee -a $OUT/${ci}-tg.log
-    (time ./bin/gpt-2 --model ${model} -s 1234 -n 64 -p "I believe the meaning of life is") 2>&1 | tee -a $OUT/${ci}-tg.log
+    (time ./bin/gpt-2-backend2 --model ${model} -s 1234 -n 64 -tt ${prompts}                       ) 2>&1 | tee -a $OUT/${ci}-tg.log
+    (time ./bin/gpt-2-backend2 --model ${model} -s 1234 -n 64 -p "I believe the meaning of life is") 2>&1 | tee -a $OUT/${ci}-tg.log
 
     (time ./bin/gpt-2-batched --model ${model} -s 1234 -n 64 -np 8 -p "I believe the meaning of life is") 2>&1 | tee -a $OUT/${ci}-tg.log
 
index 1a10880bb551336a49389714e47865212215deb5..cecc5626be90094d803f3b059f1e11c997822de2 100644 (file)
@@ -497,7 +497,7 @@ bool dollyv2_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = { };
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -555,8 +555,8 @@ bool dollyv2_eval(
                         (   n_ctx)*ggml_element_size(model.memory_v),
                         (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -666,8 +666,8 @@ bool dollyv2_eval(
     //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index af9cb4ef96c1c94c3b1a74776e78b64c0b5f63c5..91f15f0f1c7ce234c8e14295c2b0c8399a25611d 100644 (file)
@@ -1,7 +1,19 @@
 #
 # gpt-2
 
-set(TEST_TARGET gpt-2)
+set(TEST_TARGET gpt-2-ctx)
+add_executable(${TEST_TARGET} main-ctx.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+set(TEST_TARGET gpt-2-alloc)
+add_executable(${TEST_TARGET} main-alloc.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+set(TEST_TARGET gpt-2-backend)
+add_executable(${TEST_TARGET} main-backend.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+set(TEST_TARGET gpt-2-backend2)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
 
diff --git a/examples/gpt-2/main-alloc.cpp b/examples/gpt-2/main-alloc.cpp
new file mode 100644 (file)
index 0000000..c4bf986
--- /dev/null
@@ -0,0 +1,892 @@
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 50257;
+    int32_t n_ctx   = 1024;
+    int32_t n_embd  = 768;
+    int32_t n_head  = 12;
+    int32_t n_layer = 12;
+    int32_t ftype   = 1;
+    float   eps     = 1e-5f;
+};
+
+struct gpt2_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    struct ggml_tensor * ln_2_g;
+    struct ggml_tensor * ln_2_b;
+
+    // attention
+    struct ggml_tensor * c_attn_attn_w;
+    struct ggml_tensor * c_attn_attn_b;
+
+    struct ggml_tensor * c_attn_proj_w;
+    struct ggml_tensor * c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w;
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt2_model {
+    gpt2_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte;     // position embedding
+    struct ggml_tensor * wpe;     //    token embedding
+    struct ggml_tensor * lm_head; // language model head
+
+    std::vector<gpt2_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != GGML_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
+
+        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
+        printf("%s: qntvr   = %d\n", __func__, qntvr);
+
+        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        std::vector<char> buf(128);
+
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            buf.resize(len);
+            fin.read((char *) buf.data(), len);
+            word.assign(buf.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
+                __func__, fname.c_str(), model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
+        ctx_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+
+        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
+        ctx_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (6 + 12*n_layer)*512; // object overhead
+
+        printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx_size,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ false,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["model/ln_f/g"] = model.ln_f_g;
+        model.tensors["model/ln_f/b"] = model.ln_f_b;
+
+        model.tensors["model/wte"]     = model.wte;
+        model.tensors["model/wpe"]     = model.wpe;
+        model.tensors["model/lm_head"] = model.lm_head;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);
+            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+
+            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);
+            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"]        = layer.ln_1_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"]        = layer.ln_2_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"]        = layer.ln_2_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        size_t total_size = 0;
+
+        bool has_lm_head = false;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            auto tensor = model.tensors[name];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            // GPT-2 models share the WTE tensor as the LM head
+            if (name == "model/wte" && has_lm_head == false) {
+                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
+            }
+
+            if (name == "model/lm_head") {
+                has_lm_head = true;
+            }
+
+            total_size += ggml_nbytes(tensor);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// build the computation graph
+struct ggml_cgraph * gpt2_graph(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+
+    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(allocr, embd);
+
+    // avoid writing to tensors if we are only measuring the memory usage
+    if (!ggml_allocr_is_measure(allocr)) {
+        memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+    }
+
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(allocr, position);
+    if (!ggml_allocr_is_measure(allocr)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
+    }
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(allocr, KQ_scale);
+    if (!ggml_allocr_is_measure(allocr)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    // wte + wpe
+    struct ggml_tensor * inpL =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.wte, embd),
+                ggml_get_rows(ctx0, model.wpe, position));
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // [ 768, N]
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+            // cur = ln_1_g*cur + ln_1_b
+            // [ 768, N]
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+                        cur),
+                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+        }
+
+        // attn
+        // [2304, 768] - model.layers[il].c_attn_attn_w
+        // [2304,   1] - model.layers[il].c_attn_attn_b
+        // [ 768,   N] - cur (in)
+        // [2304,   N] - cur (out)
+        //
+        // cur = attn_w*cur + attn_b
+        // [2304, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_attn_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
+                    cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            // [64, N, 12]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            // [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // GG: flash attention
+            //struct ggml_tensor * V =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0,
+            //                ggml_reshape_3d(ctx0,
+            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+            //                    n_embd/n_head, n_head, n_past + N),
+            //                1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
+
+            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+
+            // K * Q
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        KQ_scale);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            // [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            // [64, N, 12]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // [64, 12, N]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // [768, N]
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        }
+
+        // projection
+        // [ 768, 768] - model.layers[il].c_attn_proj_w
+        // [ 768,   1] - model.layers[il].c_attn_proj_b
+        // [ 768,   N] - cur (in)
+        // [ 768,   N] - cur (out)
+        //
+        // cur = proj_w*cur + proj_b
+        // [768, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
+                    cur);
+        }
+
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+                // cur = ln_2_g*cur + ln_2_b
+                // [ 768, N]
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
+                            cur),
+                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
+            }
+
+            // fully connected
+            // [3072, 768] - model.layers[il].c_mlp_fc_w
+            // [3072,   1] - model.layers[il].c_mlp_fc_b
+            // [ 768,   N] - cur (in)
+            // [3072,   N] - cur (out)
+            //
+            // cur = fc_w*cur + fc_b
+            // [3072, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_fc_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
+                    cur);
+
+            // GELU activation
+            // [3072, N]
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
+            // [ 768,    1] - model.layers[il].c_mlp_proj_b
+            // [3072,    N] - cur (in)
+            // [ 768,    N] - cur (out)
+            //
+            // cur = proj_w*cur + proj_b
+            // [768, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
+                    cur);
+        }
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // norm
+    {
+        // [ 768, N]
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        // [ 768, N]
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.ln_f_g, inpL),
+                    inpL),
+                ggml_repeat(ctx0, model.ln_f_b, inpL));
+    }
+
+    // inpL = WTE * inpL
+    // [ 768, 50257] - model.lm_head
+    // [ 768, N]     - inpL
+    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    ggml_build_forward_expand(gf, inpL);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - allocr:    ggml_allocr to use to allocate the compute buffer
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+
+    // run the computation
+    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
+    static std::vector<uint8_t> work_buffer;
+    work_buffer.resize(plan.work_size);
+    plan.work_data = work_buffer.data();
+    ggml_graph_compute(gf, &plan);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    // in this case, the output tensor is the last one in the graph
+    struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result just for the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-2-117M/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gpt2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+
+        test_gpt_tokenizer(vocab, params.token_test);
+    }
+
+    // keep this buffer alive while evaluating the model
+    std::vector<uint8_t> compute_buffer;
+
+    struct ggml_allocr * allocr = NULL;
+    // allocate the compute buffer
+    {
+        allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
+
+        // create the worst case graph for memory usage estimation
+        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
+        int n_past = model.hparams.n_ctx - n_tokens;
+        struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+
+        // compute the required memory
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(allocr);
+        compute_buffer.resize(mem_size);
+        allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
+
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
+    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
+        printf("%d ", embd_inp[i]);
+    }
+    printf("\n\n");
+
+    // submit the input prompt token-by-token
+    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
+    std::vector<gpt_vocab::id> embd;
+
+    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (size_t k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (int32_t(embd.size()) >= params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/gpt-2/main-backend.cpp b/examples/gpt-2/main-backend.cpp
new file mode 100644 (file)
index 0000000..ab43726
--- /dev/null
@@ -0,0 +1,1002 @@
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 50257;
+    int32_t n_ctx   = 1024;
+    int32_t n_embd  = 768;
+    int32_t n_head  = 12;
+    int32_t n_layer = 12;
+    int32_t ftype   = 1;
+    float   eps     = 1e-5f;
+};
+
+struct gpt2_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    struct ggml_tensor * ln_2_g;
+    struct ggml_tensor * ln_2_b;
+
+    // attention
+    struct ggml_tensor * c_attn_attn_w;
+    struct ggml_tensor * c_attn_attn_b;
+
+    struct ggml_tensor * c_attn_proj_w;
+    struct ggml_tensor * c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w;
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt2_model {
+    gpt2_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte;     // position embedding
+    struct ggml_tensor * wpe;     //    token embedding
+    struct ggml_tensor * lm_head; // language model head
+
+    std::vector<gpt2_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+
+    ggml_backend_t backend = NULL;
+
+    ggml_backend_buffer_t buffer_w;
+    ggml_backend_buffer_t buffer_kv;
+
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != GGML_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
+
+        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
+        printf("%s: qntvr   = %d\n", __func__, qntvr);
+
+        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        std::vector<char> buf(128);
+
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            buf.resize(len);
+            fin.read((char *) buf.data(), len);
+            word.assign(buf.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
+                __func__, fname.c_str(), model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t buffer_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+
+        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
+        buffer_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
+        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
+
+        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+
+        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+
+        buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
+        buffer_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+
+        buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
+        buffer_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
+
+        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        buffer_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        buffer_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        buffer_size += (6 + 12*n_layer)*128; // alignment overhead
+
+        printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+        printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > 0) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        model.backend = ggml_backend_cuda_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (n_gpu_layers > 0) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        model.backend = ggml_backend_metal_init();
+        if (!model.backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (!model.backend) {
+        // fallback to CPU backend
+        fprintf(stderr, "%s: using CPU backend\n", __func__);
+        model.backend = ggml_backend_cpu_init();
+    }
+
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__);
+        return false;
+    }
+
+    // allocate weights buffer
+    model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["model/ln_f/g"] = model.ln_f_g;
+        model.tensors["model/ln_f/b"] = model.ln_f_b;
+
+        model.tensors["model/wte"]     = model.wte;
+        model.tensors["model/wpe"]     = model.wpe;
+        model.tensors["model/lm_head"] = model.lm_head;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);
+            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+
+            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);
+            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"]        = layer.ln_1_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"]        = layer.ln_2_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"]        = layer.ln_2_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
+        }
+    }
+
+    // override the default training context with the user-provided
+    model.hparams.n_ctx = n_ctx;
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+
+        // create a backend buffer (can be in host or device memory)
+        model.buffer_kv = ggml_backend_alloc_buffer(model.backend, memory_size + 256);
+
+        // allocate the tensors into the backend buffer
+        {
+            ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_kv);
+
+            // this updates the pointers in the tensors to point to the correct location in the buffer
+            // this is necessary since the ggml_context is .no_alloc == true
+            // note that the buffer can actually be a device buffer, depending on the backend
+            ggml_allocr_alloc(alloc, model.memory_k);
+            ggml_allocr_alloc(alloc, model.memory_v);
+
+            ggml_allocr_free(alloc);
+        }
+    }
+
+    // load weights
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);
+
+        size_t total_size = 0;
+
+        bool has_lm_head = false;
+
+        std::vector<char> read_buf;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            auto tensor = model.tensors[name];
+            ggml_set_name(tensor, name.c_str());
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            ggml_allocr_alloc(alloc, tensor);
+
+            if (ggml_backend_is_cpu  (model.backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(model.backend)
+#endif
+                ) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+                fin.read(read_buf.data(), ggml_nbytes(tensor));
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
+
+            // GPT-2 models share the WTE tensor as the LM head
+            if (name == "model/wte" && has_lm_head == false) {
+                //ggml_allocr_alloc(alloc, model.lm_head);
+                //ggml_backend_tensor_copy(tensor, model.lm_head);
+                model.lm_head = tensor;
+            }
+
+            if (name == "model/lm_head") {
+                has_lm_head = true;
+            }
+
+            total_size += ggml_nbytes(tensor);
+        }
+
+        ggml_allocr_free(alloc);
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// build the computation graph
+struct ggml_cgraph * gpt2_graph(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+
+    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(allocr, embd);
+
+    // avoid writing to tensors if we are only measuring the memory usage
+    if (!ggml_allocr_is_measure(allocr)) {
+        ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));
+    }
+
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(allocr, position);
+    if (!ggml_allocr_is_measure(allocr)) {
+        for (int i = 0; i < N; ++i) {
+            int32_t v = n_past + i;
+            ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));
+        }
+    }
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(allocr, KQ_scale);
+    if (!ggml_allocr_is_measure(allocr)) {
+        float s = 1.0f/sqrtf(float(n_embd)/n_head);
+        ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
+    }
+
+    // wte + wpe
+    struct ggml_tensor * inpL =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.wte, embd),
+                ggml_get_rows(ctx0, model.wpe, position));
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // [ 768, N]
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+            // cur = ln_1_g*cur + ln_1_b
+            // [ 768, N]
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        cur,
+                        model.layers[il].ln_1_g),
+                    model.layers[il].ln_1_b);
+        }
+
+        // attn
+        // [2304, 768] - model.layers[il].c_attn_attn_w
+        // [2304,   1] - model.layers[il].c_attn_attn_b
+        // [ 768,   N] - cur (in)
+        // [2304,   N] - cur (out)
+        //
+        // cur = attn_w*cur + attn_b
+        // [2304, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_attn_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    cur,
+                    model.layers[il].c_attn_attn_b);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            // [64, N, 12]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            // [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // GG: flash attention
+            //struct ggml_tensor * V =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0,
+            //                ggml_reshape_3d(ctx0,
+            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+            //                    n_embd/n_head, n_head, n_past + N),
+            //                1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
+
+            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+
+            // K * Q
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        KQ_scale);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            // [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            // [64, N, 12]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // [64, 12, N]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // [768, N]
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        }
+
+        // projection
+        // [ 768, 768] - model.layers[il].c_attn_proj_w
+        // [ 768,   1] - model.layers[il].c_attn_proj_b
+        // [ 768,   N] - cur (in)
+        // [ 768,   N] - cur (out)
+        //
+        // cur = proj_w*cur + proj_b
+        // [768, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    cur,
+                    model.layers[il].c_attn_proj_b);
+        }
+
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+                // cur = ln_2_g*cur + ln_2_b
+                // [ 768, N]
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            cur,
+                            model.layers[il].ln_2_g),
+                        model.layers[il].ln_2_b);
+            }
+
+            // fully connected
+            // [3072, 768] - model.layers[il].c_mlp_fc_w
+            // [3072,   1] - model.layers[il].c_mlp_fc_b
+            // [ 768,   N] - cur (in)
+            // [3072,   N] - cur (out)
+            //
+            // cur = fc_w*cur + fc_b
+            // [3072, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_fc_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    cur,
+                    model.layers[il].c_mlp_fc_b);
+
+            // GELU activation
+            // [3072, N]
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
+            // [ 768,    1] - model.layers[il].c_mlp_proj_b
+            // [3072,    N] - cur (in)
+            // [ 768,    N] - cur (out)
+            //
+            // cur = proj_w*cur + proj_b
+            // [768, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    cur,
+                    model.layers[il].c_mlp_proj_b);
+        }
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // norm
+    {
+        // [ 768, N]
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        // [ 768, N]
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    inpL,
+                    model.ln_f_g),
+                model.ln_f_b);
+    }
+
+    // inpL = WTE * inpL
+    // [ 768, 50257] - model.lm_head
+    // [ 768, N]     - inpL
+    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    ggml_build_forward_expand(gf, inpL);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - allocr:    ggml_allocr to use to allocate the compute buffer
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        struct ggml_allocr * allocr,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    // reset the allocator to free all the memory allocated during the previous inference
+    ggml_allocr_reset(allocr);
+
+    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
+
+    // allocate tensors
+    ggml_allocr_alloc_graph(allocr, gf);
+
+    // run the computation
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(model.backend, gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    // in this case, the output tensor is the last one in the graph
+    struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+
+    //embd_w.resize(n_vocab*N);
+    //ggml_backend_tensor_get(inpL, embd_w.data(), 0, sizeof(float)*n_vocab*N);
+
+    // return result just for the last token
+    embd_w.resize(n_vocab);
+    ggml_backend_tensor_get(inpL, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-2-117M/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gpt2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+
+        test_gpt_tokenizer(vocab, params.token_test);
+    }
+
+    // keep this buffer alive while evaluating the model
+    ggml_backend_buffer_t buf_compute;
+
+    struct ggml_allocr * allocr = NULL;
+    // allocate the compute buffer
+    {
+         // alignment required by the backend
+        size_t align = ggml_backend_get_alignment(model.backend);
+        allocr = ggml_allocr_new_measure(align);
+
+        // create the worst case graph for memory usage estimation
+        int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
+        int n_past = model.hparams.n_ctx - n_tokens;
+        struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+
+        // compute the required memory
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(allocr);
+        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
+        allocr = ggml_allocr_new_from_buffer(buf_compute);
+
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
+    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
+        printf("%d ", embd_inp[i]);
+    }
+    printf("\n\n");
+
+    // submit the input prompt token-by-token
+    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
+    std::vector<gpt_vocab::id> embd;
+
+    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (size_t k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (int32_t(embd.size()) >= params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (!params.ignore_eos && embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer_w);
+    ggml_backend_buffer_free(model.buffer_kv);
+    ggml_backend_buffer_free(buf_compute);
+    ggml_backend_free(model.backend);
+
+    return 0;
+}
index 3ba665da2da2a0c35651e61627aeb7dac5601790..99167db99b68dee077c263619d0d4aba275a3f6a 100644 (file)
@@ -551,7 +551,7 @@ struct ggml_cgraph * gpt2_graph(
     const int32_t kv_head  = ggml_allocr_is_measure(allocr) ? n_ctx - n_tokens : kv_cache.head;
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
diff --git a/examples/gpt-2/main-ctx.cpp b/examples/gpt-2/main-ctx.cpp
new file mode 100644 (file)
index 0000000..0eb71a5
--- /dev/null
@@ -0,0 +1,844 @@
+#include "ggml/ggml.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 50257;
+    int32_t n_ctx   = 1024;
+    int32_t n_embd  = 768;
+    int32_t n_head  = 12;
+    int32_t n_layer = 12;
+    int32_t ftype   = 1;
+    float   eps     = 1e-5f;
+};
+
+struct gpt2_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    struct ggml_tensor * ln_2_g;
+    struct ggml_tensor * ln_2_b;
+
+    // attention
+    struct ggml_tensor * c_attn_attn_w;
+    struct ggml_tensor * c_attn_attn_b;
+
+    struct ggml_tensor * c_attn_proj_w;
+    struct ggml_tensor * c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w;
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt2_model {
+    gpt2_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte;     // position embedding
+    struct ggml_tensor * wpe;     //    token embedding
+    struct ggml_tensor * lm_head; // language model head
+
+    std::vector<gpt2_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != GGML_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
+
+        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
+        printf("%s: qntvr   = %d\n", __func__, qntvr);
+
+        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        std::vector<char> buf(128);
+
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            buf.resize(len);
+            fin.read((char *) buf.data(), len);
+            word.assign(buf.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
+                __func__, fname.c_str(), model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
+        ctx_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+
+        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
+        ctx_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (6 + 12*n_layer)*512; // object overhead
+
+        printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx_size,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ false,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["model/ln_f/g"] = model.ln_f_g;
+        model.tensors["model/ln_f/b"] = model.ln_f_b;
+
+        model.tensors["model/wte"]     = model.wte;
+        model.tensors["model/wpe"]     = model.wpe;
+        model.tensors["model/lm_head"] = model.lm_head;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, 3*n_embd);
+            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+
+            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd);
+            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"]        = layer.ln_1_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"]        = layer.ln_2_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"]        = layer.ln_2_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        size_t total_size = 0;
+
+        bool has_lm_head = false;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            auto tensor = model.tensors[name];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            // GPT-2 models share the WTE tensor as the LM head
+            if (name == "model/wte" && has_lm_head == false) {
+                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
+            }
+
+            if (name == "model/lm_head") {
+                has_lm_head = true;
+            }
+
+            total_size += ggml_nbytes(tensor);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
+        /*.no_alloc   =*/ false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    for (int i = 0; i < N; ++i) {
+        ((int32_t *) position->data)[i] = n_past + i;
+    }
+
+    // wte + wpe
+    struct ggml_tensor * inpL =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.wte, embd),
+                ggml_get_rows(ctx0, model.wpe, position));
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // [ 768, N]
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+            // cur = ln_1_g*cur + ln_1_b
+            // [ 768, N]
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+                        cur),
+                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+        }
+
+        // attn
+        // [2304, 768] - model.layers[il].c_attn_attn_w
+        // [2304,   1] - model.layers[il].c_attn_attn_b
+        // [ 768,   N] - cur (in)
+        // [2304,   N] - cur (out)
+        //
+        // cur = attn_w*cur + attn_b
+        // [2304, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_attn_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
+                    cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            // [64, N, 12]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            // [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // GG: flash attention
+            //struct ggml_tensor * V =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0,
+            //                ggml_reshape_3d(ctx0,
+            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+            //                    n_embd/n_head, n_head, n_past + N),
+            //                1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
+
+            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+
+            // K * Q
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale_inplace(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            // [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            // [64, N, 12]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // [64, 12, N]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // [768, N]
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        }
+
+        // projection
+        // [ 768, 768] - model.layers[il].c_attn_proj_w
+        // [ 768,   1] - model.layers[il].c_attn_proj_b
+        // [ 768,   N] - cur (in)
+        // [ 768,   N] - cur (out)
+        //
+        // cur = proj_w*cur + proj_b
+        // [768, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
+                    cur);
+        }
+
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+                // cur = ln_2_g*cur + ln_2_b
+                // [ 768, N]
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
+                            cur),
+                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
+            }
+
+            // fully connected
+            // [3072, 768] - model.layers[il].c_mlp_fc_w
+            // [3072,   1] - model.layers[il].c_mlp_fc_b
+            // [ 768,   N] - cur (in)
+            // [3072,   N] - cur (out)
+            //
+            // cur = fc_w*cur + fc_b
+            // [3072, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_fc_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
+                    cur);
+
+            // GELU activation
+            // [3072, N]
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
+            // [ 768,    1] - model.layers[il].c_mlp_proj_b
+            // [3072,    N] - cur (in)
+            // [ 768,    N] - cur (out)
+            //
+            // cur = proj_w*cur + proj_b
+            // [768, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
+                    cur);
+        }
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // norm
+    {
+        // [ 768, N]
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        // [ 768, N]
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.ln_f_g, inpL),
+                    inpL),
+                ggml_repeat(ctx0, model.ln_f_b, inpL));
+    }
+
+    // inpL = WTE * inpL
+    // [ 768, 50257] - model.lm_head
+    // [ 768, N]     - inpL
+    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
+
+    // logits -> probs
+    //inpL = ggml_soft_max_inplace(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result just for the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-2-117M/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gpt2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+
+        test_gpt_tokenizer(vocab, params.token_test);
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
+    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
+        printf("%d ", embd_inp[i]);
+    }
+    printf("\n\n");
+
+    // submit the input prompt token-by-token
+    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (size_t k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (int32_t(embd.size()) >= params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
index 514f299a117767ce694e35b3db57d0c0fe9b68d2..1a249254d60a8a02eb576d8bbf9ed355f25a3f5d 100644 (file)
@@ -87,16 +87,57 @@ struct gpt2_model {
     //
     struct ggml_context * ctx;
 
-    ggml_backend_t backend = NULL;
-
-    ggml_backend_buffer_t buffer_w;
+    std::vector<ggml_backend_t> backends;
+    std::vector<ggml_backend_buffer_t> buffers_w;
     ggml_backend_buffer_t buffer_kv;
+    ggml_backend_buffer_t buffer_input;
 
     std::map<std::string, struct ggml_tensor *> tensors;
+
+    // inputs/constants
+    struct ggml_tensor * embd;
+    struct ggml_tensor * position;
+    struct ggml_tensor * KQ_scale;
 };
 
+void init_backends(gpt2_model & model, const gpt_params & params) {
+    ggml_backend_t gpu_backend = NULL;
+
+    // initialize the backends
+#ifdef GGML_USE_CUBLAS
+    if (params.n_gpu_layers > 0) {
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        gpu_backend = ggml_backend_cuda_init();
+        if (!gpu_backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (params.n_gpu_layers > 0) {
+        fprintf(stderr, "%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
+        gpu_backend = ggml_backend_metal_init();
+        if (!gpu_backend) {
+            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+        } else {
+            ggml_backend_metal_set_n_cb(gpu_backend, params.n_threads);
+        }
+    }
+#endif
+    if (gpu_backend) {
+        model.backends.push_back(gpu_backend);
+    }
+
+    // always add the CPU backend as a fallback
+    ggml_backend_t cpu_backend = ggml_backend_cpu_init();
+    ggml_backend_cpu_set_n_threads(cpu_backend, params.n_threads);
+    model.backends.push_back(cpu_backend);
+}
+
 // load the model's weights from a file
-bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, const gpt_params & params) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -177,50 +218,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
     auto & ctx = model.ctx;
 
-    size_t buffer_size = 0;
-
-    {
-        const auto & hparams = model.hparams;
-
-        const int n_embd  = hparams.n_embd;
-        const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
-        const int n_vocab = hparams.n_vocab;
-
-        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
-        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
-
-        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
-        buffer_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
-        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
-
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
-
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
-
-        buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
-        buffer_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
-
-        buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
-        buffer_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
-
-        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
-        buffer_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
-
-        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
-        buffer_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
-
-        buffer_size += (6 + 12*n_layer)*128; // alignment overhead
-
-        printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
-        printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
-    }
-
     // create the ggml context
     {
-        size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;
+        size_t n_tensors = 3 /* input */ + 2 /* kv */ + 6 + 12*model.hparams.n_layer;
         struct ggml_init_params params = {
             /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,
             /*.mem_buffer =*/ NULL,
@@ -234,43 +234,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
-    // initialize the backend
-#ifdef GGML_USE_CUBLAS
-    if (n_gpu_layers > 0) {
-        fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
-        if (!model.backend) {
-            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-        }
-    }
-#endif
-
-#ifdef GGML_USE_METAL
-    if (n_gpu_layers > 0) {
-        fprintf(stderr, "%s: using Metal backend\n", __func__);
-        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
-        model.backend = ggml_backend_metal_init();
-        if (!model.backend) {
-            fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
-        }
-    }
-#endif
-
-    if (!model.backend) {
-        // fallback to CPU backend
-        fprintf(stderr, "%s: using CPU backend\n", __func__);
-        model.backend = ggml_backend_cpu_init();
-    }
-
-    if (!model.backend) {
-        fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__);
-        return false;
-    }
-
-    // allocate weights buffer
-    model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
-
-    // prepare memory for the weights
+    // create tensors for the weights
     {
         const auto & hparams = model.hparams;
 
@@ -338,10 +302,69 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
-    // override the default training context with the user-provided
-    model.hparams.n_ctx = n_ctx;
+    // assign tensors to backends
+    init_backends(model, params);
+    ggml_backend_t backend_gpu = model.backends.front();
+    ggml_backend_t backend_cpu = model.backends.back();
+    std::map<std::string, ggml_backend_t> tensor_backends;
+    {
+        const int i_gpu_first_layer = model.hparams.n_layer - params.n_gpu_layers;
+        for (auto it : model.tensors) {
+            const std::string & name = it.first;
+            // input tensors
+            if (name == "model/wte" || name == "model/wpe") {
+                if (params.n_gpu_layers > model.hparams.n_layer) {
+                    tensor_backends[name] = backend_gpu;
+                } else {
+                    tensor_backends[name] = backend_cpu;
+                }
+            }
+            // output tensors
+            if (name == "model/ln_f/g" || name == "model/ln_f/b" || name == "model/lm_head") {
+                if (params.n_gpu_layers > 0) {
+                    tensor_backends[name] = backend_gpu;
+                } else {
+                    tensor_backends[name] = backend_cpu;
+                }
+            }
+            // layer tensors
+            if (name.substr(0, 7) == "model/h") {
+                // parse layer number
+                int layer = std::stoi(name.substr(7, 2));
+                if (layer >= i_gpu_first_layer) {
+                    tensor_backends[name] = backend_gpu;
+                } else {
+                    tensor_backends[name] = backend_cpu;
+                }
+            }
+        }
+    }
 
-    // key + value memory
+    // allocate buffers
+    std::map<ggml_backend_t, std::unique_ptr<ggml_allocr, decltype(&ggml_allocr_free)>> backend_buffers;
+    for (auto backend : model.backends) {
+        // compute the size of the buffer
+        size_t size = 0;
+        for (auto it : model.tensors) {
+            if (tensor_backends[it.first] == backend) {
+                size += ggml_nbytes(it.second) + 512;
+            }
+        }
+        if (size > 0) {
+            printf("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(backend), size/1024.0/1024.0);
+            // allocate the buffer
+            ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
+            model.buffers_w.push_back(buffer);
+
+            // create an allocator for the buffer to allocate the tensors
+            auto alloc = std::unique_ptr<ggml_allocr, decltype(&ggml_allocr_free)>(ggml_allocr_new_from_buffer(buffer), ggml_allocr_free);
+            backend_buffers.insert(std::make_pair(backend, std::move(alloc)));
+        } else {
+            model.buffers_w.push_back(NULL);
+        }
+    }
+
+    // allocate key + value memory
     {
         const auto & hparams = model.hparams;
 
@@ -355,12 +378,17 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
         model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
 
+        ggml_set_name(model.memory_k, "model/memory_k");
+        ggml_set_name(model.memory_v, "model/memory_v");
+
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
         printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
 
         // create a backend buffer (can be in host or device memory)
-        model.buffer_kv = ggml_backend_alloc_buffer(model.backend, memory_size + 256);
+        ggml_backend_t backend_kv = params.n_gpu_layers >= hparams.n_layer/2 ? backend_gpu : backend_cpu;
+        printf("%s: backend_kv = %s\n", __func__, ggml_backend_name(backend_kv));
+        model.buffer_kv = ggml_backend_alloc_buffer(backend_kv, memory_size + 512*2);
 
         // allocate the tensors into the backend buffer
         {
@@ -378,8 +406,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
     // load weights
     {
-        ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);
-
         size_t total_size = 0;
 
         bool has_lm_head = false;
@@ -440,11 +466,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
                 return false;
             }
 
+            // allocate the tensor
+            ggml_backend_t backend = tensor_backends[name];
+            ggml_allocr * alloc = backend_buffers.find(backend)->second.get();
             ggml_allocr_alloc(alloc, tensor);
+            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
 
-            if (ggml_backend_is_cpu  (model.backend)
+            if (ggml_backend_is_cpu(backend)
 #ifdef GGML_USE_METAL
-                || ggml_backend_is_metal(model.backend)
+                || ggml_backend_is_metal(backend)
 #endif
                 ) {
                 // for the CPU and Metal backend, we can read directly into the tensor
@@ -458,9 +488,10 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
             // GPT-2 models share the WTE tensor as the LM head
             if (name == "model/wte" && has_lm_head == false) {
-                //ggml_allocr_alloc(alloc, model.lm_head);
-                //ggml_backend_tensor_copy(tensor, model.lm_head);
-                model.lm_head = tensor;
+                ggml_allocr_alloc(backend_buffers.find(tensor_backends["model/lm_head"])->second.get(), model.lm_head);
+                //printf("%s: [%5.5s] %s (copied)\n", __func__, ggml_backend_name(tensor_backends["model/lm_head"]), "model/lm_head");
+                ggml_backend_tensor_copy(tensor, model.lm_head);
+                total_size += ggml_nbytes(model.lm_head);
             }
 
             if (name == "model/lm_head") {
@@ -469,20 +500,47 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
             total_size += ggml_nbytes(tensor);
         }
-
-        ggml_allocr_free(alloc);
         printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
     }
 
     fin.close();
 
+    // allocate input tensors
+    {
+        model.embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);
+        model.position = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, model.hparams.n_ctx);
+        model.KQ_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); // FIXME: should be in backend_kv, but also shouldn't matter
+
+        ggml_set_name(model.embd, "in/embd");
+        ggml_set_name(model.position, "in/position");
+        ggml_set_name(model.KQ_scale, "KQ_scale");
+
+        // add input tensors to cpu backend
+        size_t input_size = ggml_nbytes(model.embd) + ggml_nbytes(model.position) + ggml_nbytes(model.KQ_scale);
+
+        // FIXME: use cpu backend after sched impl
+        ggml_backend_t backend_input = params.n_gpu_layers >= model.hparams.n_layer ? backend_gpu : backend_cpu;
+        model.buffer_input = ggml_backend_alloc_buffer(backend_input, input_size + 512*3);
+        printf("%s: backend_in = %s (%zu bytes)\n", __func__, ggml_backend_name(backend_input), input_size);
+
+        // allocate the tensors into the backend buffer
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_input);
+        ggml_allocr_alloc(alloc, model.embd);
+        ggml_allocr_alloc(alloc, model.position);
+        ggml_allocr_alloc(alloc, model.KQ_scale);
+        ggml_allocr_free(alloc);
+
+        // initialize KQ_scale
+        float s = 1.0f/sqrtf(float(model.hparams.n_embd)/model.hparams.n_head);
+        ggml_backend_tensor_set(model.KQ_scale, &s, 0, sizeof(s));
+    }
+
     return true;
 }
 
 // build the computation graph
 struct ggml_cgraph * gpt2_graph(
         const gpt2_model & model,
-        struct ggml_allocr * allocr,
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp) {
     const int N = embd_inp.size();
@@ -495,7 +553,7 @@ struct ggml_cgraph * gpt2_graph(
     const int n_head  = hparams.n_head;
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
@@ -508,35 +566,36 @@ struct ggml_cgraph * gpt2_graph(
 
     struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    ggml_allocr_alloc(allocr, embd);
+    struct ggml_tensor * embd = ggml_view_1d(ctx0, model.embd, N, 0);
 
-    // avoid writing to tensors if we are only measuring the memory usage
-    if (!ggml_allocr_is_measure(allocr)) {
-        ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));
-    }
+    // TODO: avoid writing to tensors if we are only measuring the memory usage
+    // not critical, just a minor optimization
+
+    //if (!ggml_allocr_is_measure(allocr)) {
+        //ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(model.embd, embd_inp.data(), 0, N*ggml_element_size(embd)); // FIXME: cannot use the view here because it's not initialized yet (buffer not set), but we should
+    //}
+    //memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    ggml_allocr_alloc(allocr, position);
-    if (!ggml_allocr_is_measure(allocr)) {
+    struct ggml_tensor * position = ggml_view_1d(ctx0, model.position, N, 0);
+    //if (!ggml_allocr_is_measure(allocr)) {
         for (int i = 0; i < N; ++i) {
             int32_t v = n_past + i;
-            ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));
+            ggml_backend_tensor_set(model.position, &v, i*sizeof(int32_t), sizeof(v)); // FIXME: same
+            //((int32_t *) position->data)[i] = n_past + i;
         }
-    }
+    //}
 
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    ggml_allocr_alloc(allocr, KQ_scale);
-    if (!ggml_allocr_is_measure(allocr)) {
-        float s = 1.0f/sqrtf(float(n_embd)/n_head);
-        ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
-    }
+    struct ggml_tensor * KQ_scale = model.KQ_scale;
 
     // wte + wpe
     struct ggml_tensor * inpL =
         ggml_add(ctx0,
                 ggml_get_rows(ctx0, model.wte, embd),
                 ggml_get_rows(ctx0, model.wpe, position));
+    ggml_set_name(inpL, "inpL");
+    ggml_set_name(inpL->src[0], "wte");
+    ggml_set_name(inpL->src[1], "wpe");
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
@@ -545,6 +604,7 @@ struct ggml_cgraph * gpt2_graph(
         {
             // [ 768, N]
             cur = ggml_norm(ctx0, inpL, hparams.eps);
+            ggml_format_name(cur, "l%d.norm", il);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -553,6 +613,8 @@ struct ggml_cgraph * gpt2_graph(
                         cur,
                         model.layers[il].ln_1_g),
                     model.layers[il].ln_1_b);
+            ggml_format_name(cur, "l%d.ln_1_b", il);
+            ggml_format_name(cur->src[0], "l%d.ln_1_g", il);
         }
 
         // attn
@@ -567,10 +629,12 @@ struct ggml_cgraph * gpt2_graph(
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].c_attn_attn_w,
                     cur);
+            ggml_format_name(cur, "l%d.attn_w", il);
 
             cur = ggml_add(ctx0,
                     cur,
                     model.layers[il].c_attn_attn_b);
+            ggml_format_name(cur, "l%d.attn_b", il);
         }
 
         // self-attention
@@ -579,6 +643,10 @@ struct ggml_cgraph * gpt2_graph(
             struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
             struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
 
+            ggml_format_name(Qcur, "l%d.Qcur", il);
+            ggml_format_name(Kcur, "l%d.Kcur", il);
+            ggml_format_name(Vcur, "l%d.Vcur", il);
+
             // store key and value to memory
             if (N >= 1) {
                 struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
@@ -596,6 +664,7 @@ struct ggml_cgraph * gpt2_graph(
                             Qcur,
                             ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                         0, 2, 1, 3);
+            ggml_format_name(Q, "l%d.Q", il);
 
             // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
             // [64, n_past + N, 12]
@@ -605,6 +674,7 @@ struct ggml_cgraph * gpt2_graph(
                             ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
                             n_embd/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
+            ggml_format_name(K, "l%d.K", il);
 
             // GG: flash attention
             //struct ggml_tensor * V =
@@ -621,6 +691,7 @@ struct ggml_cgraph * gpt2_graph(
             // K * Q
             // [n_past + N, N, 12]
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_format_name(KQ, "l%d.KQ", il);
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             // [n_past + N, N, 12]
@@ -628,14 +699,17 @@ struct ggml_cgraph * gpt2_graph(
                 ggml_scale(ctx0,
                         KQ,
                         KQ_scale);
+            ggml_format_name(KQ_scaled, "l%d.KQ_scaled", il);
 
             // KQ_masked = mask_past(KQ_scaled)
             // [n_past + N, N, 12]
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            ggml_format_name(KQ_masked, "l%d.KQ_masked", il);
 
             // KQ = soft_max(KQ_masked)
             // [n_past + N, N, 12]
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            ggml_format_name(KQ_soft_max, "l%d.KQ_soft_max", il);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             // [n_past + N, 64, 12]
@@ -647,20 +721,24 @@ struct ggml_cgraph * gpt2_graph(
                                 n_embd/n_head, n_head, n_past + N),
                             1, 2, 0, 3),
                         ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+            ggml_format_name(V_trans, "l%d.V_trans", il);
 
             // KQV = transpose(V) * KQ_soft_max
             // [64, N, 12]
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            ggml_format_name(KQV, "l%d.KQV", il);
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             // [64, 12, N]
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_format_name(KQV_merged, "l%d.KQV_merged", il);
 
             // cur = KQV_merged.contiguous().view(n_embd, N)
             // [768, N]
             cur = ggml_cpy(ctx0,
                     KQV_merged,
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            ggml_format_name(cur, "l%d.KQV_merged_contiguous", il);
         }
 
         // projection
@@ -675,14 +753,17 @@ struct ggml_cgraph * gpt2_graph(
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].c_attn_proj_w,
                     cur);
+            ggml_format_name(cur, "l%d.attn_proj_w", il);
 
             cur = ggml_add(ctx0,
                     cur,
                     model.layers[il].c_attn_proj_b);
+            ggml_format_name(cur, "l%d.attn_proj_b", il);
         }
 
         // add the input
         cur = ggml_add(ctx0, cur, inpL);
+        ggml_format_name(cur, "l%d.add", il);
 
         struct ggml_tensor * inpFF = cur;
 
@@ -691,6 +772,7 @@ struct ggml_cgraph * gpt2_graph(
             // norm
             {
                 cur = ggml_norm(ctx0, inpFF, hparams.eps);
+                ggml_format_name(cur, "l%d.FFnorm", il);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -699,6 +781,8 @@ struct ggml_cgraph * gpt2_graph(
                             cur,
                             model.layers[il].ln_2_g),
                         model.layers[il].ln_2_b);
+                ggml_format_name(cur, "l%d.ln_2_b", il);
+                ggml_format_name(cur->src[0], "l%d.ln_2_g", il);
             }
 
             // fully connected
@@ -712,14 +796,17 @@ struct ggml_cgraph * gpt2_graph(
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].c_mlp_fc_w,
                     cur);
+            ggml_format_name(cur, "l%d.mlp_fc_w", il);
 
             cur = ggml_add(ctx0,
                     cur,
                     model.layers[il].c_mlp_fc_b);
+            ggml_format_name(cur, "l%d.mlp_fc_b", il);
 
             // GELU activation
             // [3072, N]
             cur = ggml_gelu(ctx0, cur);
+            ggml_format_name(cur, "l%d.gelu", il);
 
             // projection
             // [ 768, 3072] - model.layers[il].c_mlp_proj_w
@@ -732,20 +819,24 @@ struct ggml_cgraph * gpt2_graph(
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].c_mlp_proj_w,
                     cur);
+            ggml_format_name(cur, "l%d.mlp_proj_w", il);
 
             cur = ggml_add(ctx0,
                     cur,
                     model.layers[il].c_mlp_proj_b);
+            ggml_format_name(cur, "l%d.mlp_proj_b", il);
         }
 
         // input for next layer
         inpL = ggml_add(ctx0, cur, inpFF);
+        ggml_format_name(inpL, "l%d.add2", il);
     }
 
     // norm
     {
         // [ 768, N]
         inpL = ggml_norm(ctx0, inpL, hparams.eps);
+        ggml_format_name(inpL, "out_norm");
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
@@ -754,12 +845,15 @@ struct ggml_cgraph * gpt2_graph(
                     inpL,
                     model.ln_f_g),
                 model.ln_f_b);
+        ggml_format_name(inpL, "out_ln_f_b");
+        ggml_format_name(inpL->src[0], "out_ln_f_g");
     }
 
     // inpL = WTE * inpL
     // [ 768, 50257] - model.lm_head
     // [ 768, N]     - inpL
     inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
+    ggml_format_name(inpL, "out_lm_head");
 
     // logits -> probs
     //inpL = ggml_soft_max(ctx0, inpL);
@@ -782,8 +876,7 @@ struct ggml_cgraph * gpt2_graph(
 //
 bool gpt2_eval(
         const gpt2_model & model,
-        struct ggml_allocr * allocr,
-        const int n_threads,
+        ggml_backend_sched_t sched,
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float>         & embd_w) {
@@ -793,24 +886,10 @@ bool gpt2_eval(
 
     const int n_vocab = hparams.n_vocab;
 
-    // reset the allocator to free all the memory allocated during the previous inference
-    ggml_allocr_reset(allocr);
-
-    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
-
-    // allocate tensors
-    ggml_allocr_alloc_graph(allocr, gf);
+    struct ggml_cgraph * gf = gpt2_graph(model, n_past, embd_inp);
 
     // run the computation
-    if (ggml_backend_is_cpu(model.backend)) {
-        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
-    }
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(model.backend)) {
-        ggml_backend_metal_set_n_cb(model.backend, n_threads);
-    }
-#endif
-    ggml_backend_graph_compute(model.backend, gf);
+    ggml_backend_sched_graph_compute(sched, gf);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
@@ -862,7 +941,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
+        if (!gpt2_model_load(params.model, model, vocab, params)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -872,30 +951,34 @@ int main(int argc, char ** argv) {
         test_gpt_tokenizer(vocab, params.token_test);
     }
 
-    // keep this buffer alive while evaluating the model
-    ggml_backend_buffer_t buf_compute;
-
-    struct ggml_allocr * allocr = NULL;
-    // allocate the compute buffer
+    // create the backend scheduler
+    // the scheduler handles the allocation of the compute buffers and the scheduling of the computation between the different backends
+    ggml_backend_sched_t sched;
     {
-         // alignment required by the backend
-        size_t align = ggml_backend_get_alignment(model.backend);
-        allocr = ggml_allocr_new_measure(align);
+        // initialize the scheduler
+        sched = ggml_backend_sched_new(model.backends.data(), model.backends.size());
 
         // create the worst case graph for memory usage estimation
         int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
         int n_past = model.hparams.n_ctx - n_tokens;
-        struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+        struct ggml_cgraph * gf = gpt2_graph(model, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
 
-        // compute the required memory
-        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+        ggml_backend_sched_init_measure(sched, gf);
 
-        // recreate the allocator with the required memory
-        ggml_allocr_free(allocr);
-        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
-        allocr = ggml_allocr_new_from_buffer(buf_compute);
 
-        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+        // compute the required memory
+        size_t mem_size = 0;
+        for (size_t i = 0; i < model.backends.size(); i++) {
+            ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(sched, model.backends[i]);
+            size_t size = ggml_backend_buffer_get_size(buf);
+            if (size > 0) {
+                mem_size += size;
+                printf("%s: %8s compute buffer size = %8.2f MB\n", __func__, ggml_backend_name(model.backends[i]), size/1024.0/1024.0);
+                //printf("%s: %8s compute buffer size = %zu bytes\n", __func__, ggml_backend_name(model.backends[i]), size);
+            }
+        }
+
+        printf("%s: total compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
     }
 
     int n_past = 0;
@@ -926,7 +1009,7 @@ int main(int argc, char ** argv) {
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
+            if (!gpt2_eval(model, sched, n_past, embd, logits)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -975,7 +1058,7 @@ int main(int argc, char ** argv) {
         fflush(stdout);
 
         // end of text token
-        if (!params.ignore_eos && embd.back() == 50256) {
+        if (embd.back() == 50256) {
             break;
         }
     }
@@ -993,10 +1076,14 @@ int main(int argc, char ** argv) {
 
     ggml_free(model.ctx);
 
-    ggml_backend_buffer_free(model.buffer_w);
+    ggml_backend_sched_free(sched);
     ggml_backend_buffer_free(model.buffer_kv);
-    ggml_backend_buffer_free(buf_compute);
-    ggml_backend_free(model.backend);
+    for (auto & buf : model.buffers_w) {
+        ggml_backend_buffer_free(buf);
+    }
+    for (auto backend : model.backends) {
+        ggml_backend_free(backend);
+    }
 
     return 0;
 }
index c2c4c1fef9067c66adec5b95d038f6c1c3c31729..fbfbd4715f48496e17746badc59f7842ed79f418 100644 (file)
@@ -425,7 +425,7 @@ bool gptj_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -471,8 +471,8 @@ bool gptj_eval(
                         (   n_ctx)*ggml_element_size(model.memory_v),
                         (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -590,8 +590,8 @@ bool gptj_eval(
     //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index 43e16ce3e9ed8087c61615ea6ba23e7924a063a3..5fbe5148e6f33a6f041d9b79e187b43161fd44c3 100644 (file)
@@ -477,7 +477,7 @@ bool gpt_neox_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -537,8 +537,8 @@ bool gpt_neox_eval(
                         (   n_ctx)*ggml_element_size(model.memory_v),
                         (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -653,8 +653,8 @@ bool gpt_neox_eval(
     //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index 7949e9a6069768036ce7c0b419b7465288898936..b0135035bfc3d24b06c166a06abf4c9d90438352 100644 (file)
@@ -61,7 +61,7 @@ int mnist_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 28, 28, 1, 1);
     memcpy(input->data, digit.data(), ggml_nbytes(input));
@@ -86,16 +86,16 @@ int mnist_eval(
     ggml_tensor * probs = ggml_soft_max(ctx0, cur);
     ggml_set_name(probs, "probs");
 
-    ggml_build_forward_expand(&gf, probs);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, probs);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //ggml_graph_print(&gf);
-    ggml_graph_dump_dot(&gf, NULL, "mnist-cnn.dot");
+    ggml_graph_dump_dot(gf, NULL, "mnist-cnn.dot");
 
     if (fname_cgraph) {
         // export the compute graph for later use
         // see the "mnist-cpu" example
-        ggml_graph_export(&gf, fname_cgraph);
+        ggml_graph_export(gf, fname_cgraph);
 
         fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
     }
index 6e1e39801ad88cac978f21c1cb37476a714d0248..3b759b0f3c0272de127ae0cdd50f90f0a6c53c5d 100644 (file)
@@ -39,10 +39,10 @@ int mnist_eval(
     struct ggml_context * ctx_data = NULL;
     struct ggml_context * ctx_eval = NULL;
 
-    struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+    struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
 
     // param export/import test
-    GGML_ASSERT(ggml_graph_get_tensor(&gfi, "fc1_bias")->op_params[0] == int(0xdeadbeef));
+    GGML_ASSERT(ggml_graph_get_tensor(gfi, "fc1_bias")->op_params[0] == int(0xdeadbeef));
 
     // allocate work context
     // needed during ggml_graph_compute() to allocate a work tensor
@@ -57,12 +57,12 @@ int mnist_eval(
 
     struct ggml_context * ctx_work = ggml_init(params);
 
-    struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input");
+    struct ggml_tensor * input = ggml_graph_get_tensor(gfi, "input");
     memcpy(input->data, digit.data(), ggml_nbytes(input));
 
-    ggml_graph_compute_with_ctx(ctx_work, &gfi, n_threads);
+    ggml_graph_compute_with_ctx(ctx_work, gfi, n_threads);
 
-    const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs"));
+    const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(gfi, "probs"));
 
     const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
 
index a8d47ac9c70c315e66076aa5094fb77f1b044cdf..7d0eec8f186bb13c23e3f067cc9513a7850673be 100644 (file)
@@ -35,7 +35,7 @@ int mnist_eval(
     struct ggml_context * ctx_data = NULL;
     struct ggml_context * ctx_eval = NULL;
 
-    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
 
     // allocate work context
     static size_t buf_size = 128ull*1024*1024; // TODO
@@ -50,12 +50,12 @@ int mnist_eval(
     struct ggml_context * ctx_work = ggml_init(params);
 
     // this allocates all Metal resources and memory buffers
-    auto ctx_mtl = mnist_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
+    auto ctx_mtl = mnist_mtl_init(ctx_data, ctx_eval, ctx_work, gf);
 
     int prediction = -1;
 
     for (int i = 0; i < 1; ++i) {
-        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "input");
+        struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
 
         if (i % 2 == 0) {
             memcpy(input->data, digit.data(), ggml_nbytes(input));
@@ -64,7 +64,7 @@ int mnist_eval(
         }
 
         // the actual inference happens here
-        prediction = mnist_mtl_eval(ctx_mtl, &gf);
+        prediction = mnist_mtl_eval(ctx_mtl, gf);
     }
 
     mnist_mtl_free(ctx_mtl);
index 33986e4ef573205d7e6f912f632461a0fa16c620..a03b95341ece2ce9babac89540e8e0eabb49e368 100644 (file)
@@ -178,7 +178,7 @@ int mnist_eval(
 
     const auto & hparams = model.hparams;
 
-    static size_t buf_size = hparams.n_input * sizeof(float) * 4;
+    static size_t buf_size = hparams.n_input * sizeof(float) * 32;
     static void * buf = malloc(buf_size);
 
     struct ggml_init_params params = {
@@ -188,7 +188,7 @@ int mnist_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_input);
     memcpy(input->data, digit.data(), ggml_nbytes(input));
@@ -203,16 +203,16 @@ int mnist_eval(
     ggml_set_name(probs, "probs");
 
     // build / export / run the computation graph
-    ggml_build_forward_expand(&gf, probs);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, probs);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //ggml_graph_print   (&gf);
-    ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
+    ggml_graph_dump_dot(gf, NULL, "mnist.dot");
 
     if (fname_cgraph) {
         // export the compute graph for later use
         // see the "mnist-cpu" example
-        ggml_graph_export(&gf, "mnist.ggml");
+        ggml_graph_export(gf, "mnist.ggml");
 
         fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
     }
index 9b2fa02c22dbe0445570a03de826303b187d506d..b934e499430dada18adb5b0328f3cf9899808d91 100644 (file)
@@ -499,7 +499,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
@@ -544,8 +544,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
                     ggml_view_1d(ctx0, model.memory_v, N * n_embd,
                                  (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0,
@@ -650,8 +650,8 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     // inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     // std::cout << "Qcur" << std::endl;
     // print_tensor(Qcur);
index b8338cbbdeb681e70dc3c0b0619c5050943cbf95..f851470e9309441ffca943720e7215af2ef3a375 100644 (file)
@@ -476,7 +476,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
@@ -515,8 +515,8 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
                     ggml_view_1d(ctx0, model.memory_v, N * n_embd,
                                  (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0,
@@ -614,8 +614,8 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
     // inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     // std::cout << "Qcur" << std::endl;
     // print_tensor(Qcur);
index 1de41f30a8e6db5d5c5c3441fe911bb981073cbe..7c4130fddace7adb53a16a0effa529529bc3298f 100644 (file)
@@ -2109,7 +2109,7 @@ int main(int argc, char ** argv) {
 
     static const size_t tensor_alignment = 32;
     {
-        state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+        state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
         state.allocr = ggml_allocr_new_measure(tensor_alignment);
         struct ggml_cgraph * gf_measure = sam_encode_image(model, state, img1);
         if (!gf_measure) {
@@ -2144,7 +2144,7 @@ int main(int argc, char ** argv) {
         state.work_buffer.clear();
     }
     {
-        state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+        state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
         state.allocr = ggml_allocr_new_measure(tensor_alignment);
 
         // TODO: user input
index 946210b85915adb700c76a57e44bb8d6501b3e3a..603473a36a109b195ade9284c6c995de9a07a76a 100644 (file)
@@ -464,7 +464,7 @@ bool starcoder_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -528,8 +528,8 @@ bool starcoder_eval(
                 struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -716,8 +716,8 @@ bool starcoder_eval(
     //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index a224115a1480518643575bcdb5dcfd32340dee4a..1ab039c310a0eb94b8982bbc9a2615ea440167d6 100644 (file)
@@ -673,7 +673,7 @@ bool starcoder_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 
@@ -739,8 +739,8 @@ bool starcoder_eval(
                 struct ggml_tensor * k = ggml_view_1d(ctx0, cache.k, N*n_embd, (ggml_element_size(cache.k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_1d(ctx0, cache.v, N*n_embd, (ggml_element_size(cache.v)*n_embd)*(il*n_ctx + n_past));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -927,8 +927,8 @@ bool starcoder_eval(
     //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index 1224be9bf51bcb4ae10dc5adfef9015aca807da0..6d4a2c127186fe1a8fd2068d9410f6a7adadaab2 100644 (file)
@@ -655,7 +655,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct
     auto & meta  = allocr.meta;
     auto & data  = allocr.data;
 
-    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+    meta.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
 
     alloc = ggml_allocr_new_measure(tensor_alignment);
 
@@ -5413,7 +5413,7 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     // b: N*N*sizeof(float)
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
+    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
     std::vector<uint8_t> work;
 
     // put a bunch of random data in the buffer
@@ -5464,17 +5464,19 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
 
             struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
 
-            struct ggml_cgraph gf = ggml_build_forward(c);
+            struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+            ggml_build_forward_expand(gf, c);
 
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_helper(work, &gf, n_threads);
+            ggml_graph_compute_helper(work, gf, n_threads);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_helper(work, &gf, n_threads);
+                ggml_graph_compute_helper(work, gf, n_threads);
 
                 const int64_t t1 = ggml_time_us();
 
index 7c69fce4449238c068efa019f100e3a757274c19..ef848f3bc0754dae8ab2c10956ee91f2ed6c1288 100644 (file)
@@ -109,7 +109,7 @@ static bool load_labels(const char * filename, std::vector<std::string> & labels
     std::string line;
     while (std::getline(file_in, line)) {
         labels.push_back(line);
-    }    
+    }
     GGML_ASSERT(labels.size() == 80);
     return true;
 }
@@ -365,7 +365,7 @@ void detect(yolo_image & img, const yolo_model & model, float thresh, const std:
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
     std::vector<detection> detections;
 
     yolo_image sized = letterbox_image(img, model.width, model.height);
@@ -420,9 +420,9 @@ void detect(yolo_image & img, const yolo_model & model, float thresh, const std:
     struct ggml_tensor * layer_22 = result;
     print_shape(22, result);
 
-    ggml_build_forward_expand(&gf, layer_15);
-    ggml_build_forward_expand(&gf, layer_22);
-    ggml_graph_compute_with_ctx(ctx0, &gf, 1);
+    ggml_build_forward_expand(gf, layer_15);
+    ggml_build_forward_expand(gf, layer_22);
+    ggml_graph_compute_with_ctx(ctx0, gf, 1);
 
     yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_15};
     apply_yolo(yolo16);
index e38758878b91a36b886c21732403809c78765557..dde2a06bf803098a85d60c15a47953753e4a0bd8 100644 (file)
@@ -6,27 +6,79 @@
 extern "C" {
 #endif
 
+struct ggml_backend;
 struct ggml_backend_buffer;
 
-GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
-GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
-GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+//
+// Legacy API
+//
+
+typedef struct ggml_allocr * ggml_allocr_t;
+
+// initialize allocator for use with CPU backend only
+GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
+
+// initialize allocator for use with ggml-backend
+GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
+
+GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
-
-GGML_API void   ggml_allocr_free       (struct ggml_allocr * alloc);
-GGML_API bool   ggml_allocr_is_measure (struct ggml_allocr * alloc);
-GGML_API void   ggml_allocr_reset      (struct ggml_allocr * alloc);
-GGML_API void   ggml_allocr_alloc      (struct ggml_allocr * alloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
-GGML_API size_t ggml_allocr_max_size   (struct ggml_allocr * alloc);
-
-GGML_API size_t ggml_allocr_alloc_graph_n(
-                    struct ggml_allocr * alloc,
-                    struct ggml_cgraph ** graphs, int n_graphs,
-                    struct ggml_tensor *** inputs, struct ggml_tensor *** outputs);
+GGML_API void   ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
+
+GGML_API void   ggml_allocr_free       (ggml_allocr_t alloc);
+GGML_API bool   ggml_allocr_is_measure (ggml_allocr_t alloc);
+GGML_API void   ggml_allocr_reset      (ggml_allocr_t alloc);
+GGML_API void   ggml_allocr_alloc      (ggml_allocr_t alloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_allocr_max_size   (ggml_allocr_t alloc);
+
+GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
+
+//
+// ggml-backend v2 API
+//
+
+// Seperate tensor and graph allocator objects
+// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
+// The original API is kept as a wrapper around the new API
+
+// Tensor allocator
+typedef struct ggml_tallocr * ggml_tallocr_t;
+
+GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
+
+GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+
+GGML_API void   ggml_tallocr_free       (ggml_tallocr_t talloc);
+GGML_API bool   ggml_tallocr_is_measure (ggml_tallocr_t talloc);
+GGML_API void   ggml_tallocr_reset      (ggml_tallocr_t talloc);
+GGML_API void   ggml_tallocr_alloc      (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_tallocr_max_size   (ggml_tallocr_t talloc);
+
+
+// Graph allocator
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(void);
+GGML_API void   ggml_gallocr_free(ggml_gallocr_t galloc);
+
+GGML_API void   ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
+GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
+
+// Allocate tensors from the allocators given by the hash table
+GGML_API void   ggml_gallocr_alloc_graph_n(
+                    ggml_gallocr_t galloc,
+                    struct ggml_cgraph * graph,
+                    struct ggml_hash_set hash_set,
+                    ggml_tallocr_t * hash_node_talloc);
 
 #ifdef  __cplusplus
 }
index da134b0dbed514423b1223814b2346c4048a2854..966687320ac96d971e9d635429d4ea1018a7ce57 100644 (file)
@@ -1,51 +1,20 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-alloc.h"
 
 #ifdef  __cplusplus
 extern "C" {
 #endif
-    struct ggml_backend;
-    struct ggml_backend_buffer;
-
-    // type-erased backend-specific types / wrappers
-    typedef void * ggml_backend_context_t;
-    typedef void * ggml_backend_graph_plan_t;
-    typedef void * ggml_backend_buffer_context_t;
-
-    // avoid accessing internals of these types
-    typedef struct ggml_backend        * ggml_backend_t;
-    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
 
     //
-    // backend buffer
+    // Backend buffer
     //
 
-    struct ggml_backend_buffer_i {
-        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
-        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
-        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
-        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
-        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
-    };
-
-    // TODO: hide behind API
-    struct ggml_backend_buffer {
-        struct ggml_backend_buffer_i iface;
-
-        ggml_backend_t                backend;
-        ggml_backend_buffer_context_t context;
-
-        size_t size;
-    };
+    struct ggml_backend_buffer;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
 
     // backend buffer functions
-    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
-            struct ggml_backend                  * backend,
-            struct ggml_backend_buffer_i           iface,
-                   ggml_backend_buffer_context_t   context,
-                   size_t                          size);
-
     GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
     GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
@@ -55,50 +24,13 @@ extern "C" {
     GGML_API void   ggml_backend_buffer_free_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
 
     //
-    // backend
+    // Backend
     //
 
-    struct ggml_backend_i {
-        const char * (*get_name)(ggml_backend_t backend);
-
-        void (*free)(ggml_backend_t backend);
-
-        // buffer allocation
-        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
-
-        // get buffer alignment
-        size_t (*get_alignment)(ggml_backend_t backend);
-
-        // tensor data access
-        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
-        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        void (*synchronize)     (ggml_backend_t backend);
-
-        // (optional) copy tensor between different backends, allow for single-copy tranfers
-        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-
-        // compute graph with a plan
-        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-        void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-
-        // compute graph without a plan
-        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
-
-        // check if the backend supports an operation
-        bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
-    };
-
-    // TODO: hide behind API
-    struct ggml_backend {
-        struct ggml_backend_i iface;
-
-        ggml_backend_context_t context;
-    };
+    struct ggml_backend;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
 
-    // backend helper functions
     GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
 
     GGML_API const char * ggml_backend_name(ggml_backend_t backend);
@@ -133,11 +65,72 @@ extern "C" {
     GGML_API ggml_backend_t ggml_backend_cpu_init(void);
 
     GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
-
     GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
 
+    // Create a backend buffer from an existing pointer
     GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
 
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backends to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
+        // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
+
+        // initialize buffers from a measure graph
+        measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
+
+        // in build_graph:
+        build_graph(...) {
+            // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
+            alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
+            ggml_allocr_alloc(alloc_cpu, tensor);
+
+            // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+            struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+            ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
+        }
+
+        // allocate backend buffers from measure graph
+        ggml_backend_sched_init_measure(sched, measure_graph);
+
+        // the scheduler is now ready to compute graphs
+
+        // compute
+        graph = build_graph(sched);
+        ggml_backend_sched_graph_compute(sched, graph);
+    */
+
+    struct ggml_backend_sched;
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // Initialize a backend scheduler
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
+
+    GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+    GGML_API ggml_tallocr_t        ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+
+    // Allocate a graph on the backend scheduler
+    GGML_API void ggml_backend_sched_graph_compute(
+            ggml_backend_sched_t sched,
+            struct ggml_cgraph * graph);
+
 #ifdef  __cplusplus
 }
 #endif
index a8c2a4ad940331a99eb92590eff7e788b7d01b80..253e4e27f68704ee0a15ad27720254fc8083cc68 100644 (file)
@@ -58,7 +58,8 @@
 //   {
 //       ...
 //
-//       struct ggml_cgraph gf = ggml_build_forward(f);
+//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//       ggml_build_forward_expand(gf, f);
 //
 //       // set the input variable and parameter values
 //       ggml_set_f32(x, 2.0f);
 #define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
-#define GGML_MAX_DIMS          4
-#define GGML_MAX_NODES         16384
-#define GGML_MAX_PARAMS        1024
-#define GGML_MAX_CONTEXTS      64
-#define GGML_MAX_SRC           6
-#define GGML_MAX_NAME          64
-#define GGML_MAX_OP_PARAMS     32
-#define GGML_DEFAULT_N_THREADS 4
-
+#define GGML_MAX_DIMS           4
+#define GGML_MAX_PARAMS         1024
+#define GGML_MAX_CONTEXTS       64
+#define GGML_MAX_SRC            6
+#define GGML_MAX_NAME           64
+#define GGML_MAX_OP_PARAMS      32
+#define GGML_DEFAULT_N_THREADS  4
+#define GGML_DEFAULT_GRAPH_SIZE 2048
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
 #else
     do { \
         if (!(x)) { \
             fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            abort(); \
+            fflush(stderr); \
+            fflush(stdout); \
+            ggml_print_backtrace(); \
+            exit(1); \
         } \
     } while (0)
 
@@ -532,37 +535,33 @@ extern "C" {
 
         int n_threads;
 
-        // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
-        int n_tasks[GGML_MAX_NODES];
-
         // abort ggml_graph_compute when true
         bool (*abort_callback)(void * data);
         void * abort_callback_data;
     };
 
-    // next prime after GGML_MAX_NODES
-    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
-    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
-    // #define GGML_GRAPH_HASHTABLE_SIZE 8273
-    // #define GGML_GRAPH_HASHTABLE_SIZE 16411
-    #define GGML_GRAPH_HASHTABLE_SIZE 32771
-
     enum ggml_cgraph_eval_order {
         GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
         GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
         GGML_CGRAPH_EVAL_ORDER_COUNT
     };
 
+    struct ggml_hash_set {
+        size_t size;
+        struct ggml_tensor ** keys;
+    };
+
     // computation graph
     struct ggml_cgraph {
+        int size;
         int n_nodes;
         int n_leafs;
 
-        struct ggml_tensor * nodes[GGML_MAX_NODES];
-        struct ggml_tensor * grads[GGML_MAX_NODES];
-        struct ggml_tensor * leafs[GGML_MAX_NODES];
+        struct ggml_tensor ** nodes;
+        struct ggml_tensor ** grads;
+        struct ggml_tensor ** leafs;
 
-        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+        struct ggml_hash_set visited_hash_table;
 
         enum ggml_cgraph_eval_order order;
 
@@ -572,8 +571,6 @@ extern "C" {
         int64_t perf_time_us;
     };
 
-    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
-
     // scratch buffer
     struct ggml_scratch {
         size_t offs;
@@ -618,6 +615,8 @@ extern "C" {
     GGML_API int64_t ggml_cycles(void);
     GGML_API int64_t ggml_cycles_per_ms(void);
 
+    GGML_API void    ggml_print_backtrace(void);
+
     GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
     GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
 
@@ -1720,19 +1719,22 @@ extern "C" {
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
 
-    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
-    GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-
     // graph allocation in a context
-    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
-    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+    GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
+    GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph * ggml_graph_view        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
+    GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
+
     GGML_API size_t ggml_graph_overhead(void);
+    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
 
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
-    GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+    GGML_API int               ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
 
     // same as ggml_graph_compute() but the work data is allocated as a part of the context
     // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
@@ -1740,8 +1742,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
-    GGML_API void               ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
-    GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+    GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
 
     // print info and performance information for the graph
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
@@ -1804,6 +1806,8 @@ extern "C" {
     struct ggml_opt_params {
         enum ggml_opt_type type;
 
+        size_t graph_size;
+
         int n_threads;
 
         // delta-based convergence test
index 50c0b92c4a7b9eec1c13f828db5fb52502e5a135..192e01eae8f0b87159383700c2873dc30188dc7b 100644 (file)
@@ -253,6 +253,8 @@ add_library(${TARGET}
     ggml.c
     ggml-alloc.c
     ggml-backend.c
+    ggml-impl.h
+    ggml-backend-impl.h
     ../include/ggml/ggml.h
     ../include/ggml/ggml-alloc.h
     ../include/ggml/ggml-backend.h
@@ -323,7 +325,9 @@ endif()
 
 set (GGML_PUBLIC_HEADERS
      ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml.h
-     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-alloc.h)
+     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-alloc.h
+     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-backend.h)
+
 set_target_properties(${TARGET} PROPERTIES
                       PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
 
index 34eba3f830e8496a7d9bc77edb3b0ddc2bd93017..f400dc56a20d054ee7d8879447bec42c2732d59e 100644 (file)
@@ -1,51 +1,21 @@
 #include "ggml-alloc.h"
-#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
 #include "ggml.h"
+#include "ggml-impl.h"
 #include <assert.h>
+#include <limits.h>
 #include <stdarg.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 
-
-#define UNUSED(x) (void)(x)
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+#define MAX_FREE_BLOCKS 256
 
 //#define GGML_ALLOCATOR_DEBUG
 
-//#define AT_PRINTF printf
-#define AT_PRINTF(...) ((void)0)
-
-struct hash_node {
-    struct ggml_tensor * t;
-    int n_children;
-    int n_views;
-};
-
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
-    size_t h = hash(t);
-
-    // linear probing
-    size_t i = h;
-    while (hash_table[i].t != NULL) {
-        if (hash_table[i].t == t) {
-            return &hash_table[i];
-        }
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // hash table is full
-            GGML_ASSERT(false);
-        }
-    }
-
-    hash_table[i].t = t;
-    return &hash_table[i];
-}
+//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+#define AT_PRINTF(...)
 
 // TODO: GGML_PAD ?
 static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
@@ -59,20 +29,18 @@ struct free_block {
     size_t size;
 };
 
-#define MAX_FREE_BLOCKS 256
-
-struct ggml_allocr {
+struct ggml_tallocr {
     struct ggml_backend_buffer * buffer;
     bool buffer_owned;
-    void * data;
+    void * base;
     size_t alignment;
+
     int n_free_blocks;
     struct free_block free_blocks[MAX_FREE_BLOCKS];
-    struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
     size_t max_size;
+
     bool measure;
-    int parse_seq[GGML_MAX_CONCUR];
-    int parse_seq_len;
 
 #ifdef GGML_ALLOCATOR_DEBUG
     struct ggml_tensor * allocated_tensors[1024];
@@ -80,7 +48,7 @@ struct ggml_allocr {
 };
 
 #ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == NULL) {
             alloc->allocated_tensors[i] = tensor;
@@ -89,7 +57,7 @@ static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor
     }
     GGML_ASSERT(!"out of allocated_tensors");
 }
-static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == tensor ||
             (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
@@ -103,7 +71,7 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
 #endif
 
 // check if a tensor is allocated by this buffer
-static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
+static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
     return tensor->buffer == alloc->buffer;
 }
 
@@ -111,7 +79,7 @@ static bool ggml_is_view(struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
     GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
 
@@ -162,9 +130,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 
     tensor->data = addr;
-    AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
     tensor->buffer = alloc->buffer;
-    ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+    if (!alloc->measure) {
+        ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+    }
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
@@ -180,16 +149,16 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 #endif
 
-    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
+    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
 }
 
 // this is a very naive implementation, but for our case the number of free blocks should be very small
-static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
-    if (ggml_allocr_is_own(alloc, tensor) == false) {
+static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
+    if (ggml_tallocr_is_own(alloc, tensor) == false) {
         // the tensor was not allocated in this buffer
         // this can happen because the graph allocator will try to free weights and other tensors from different buffers
         // the easiest way to deal with this is just to ignore it
-        AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
+        // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
         return;
     }
 
@@ -199,7 +168,9 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
-    ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
+    if (!alloc->measure) {
+        ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
+    }
 
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
@@ -253,91 +224,180 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
-    for (int i = 0; i < n; i++) {
-        alloc->parse_seq[i] = list[i];
-    }
-    alloc->parse_seq_len = n;
-}
-
-void ggml_allocr_reset(struct ggml_allocr * alloc) {
+void ggml_tallocr_reset(ggml_tallocr_t alloc) {
     alloc->n_free_blocks = 1;
-    size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
-    alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
-    alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
+    size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
+    alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
+
+    if (alloc->measure) {
+        alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+    } else {
+        alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
+    }
 }
 
-struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
+ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
     struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
 
-    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
+    ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
-    *alloc = (struct ggml_allocr){
+    *alloc = (struct ggml_tallocr) {
         /*.buffer        = */ buffer,
         /*.buffer_owned  = */ true,
         /*.base          = */ ggml_backend_buffer_get_base(buffer),
         /*.alignment     = */ alignment,
         /*.n_free_blocks = */ 0,
         /*.free_blocks   = */ {{0}},
-        /*.hash_table    = */ {{0}},
         /*.max_size      = */ 0,
         /*.measure       = */ false,
-        /*.parse_seq     = */ {0},
-        /*.parse_seq_len = */ 0,
 #ifdef GGML_ALLOCATOR_DEBUG
         /*.allocated_tensors = */ {0},
 #endif
     };
 
-    ggml_allocr_reset(alloc);
+    ggml_tallocr_reset(alloc);
+
+    return alloc;
+}
+
+ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
+    ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
+    alloc->measure = true;
 
     return alloc;
 }
 
-struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
-    struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment);
+ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
+    // create a backend buffer to get the correct tensor allocation sizes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
+
+    // TODO: move alloc initialization to a common ggml_tallocr_new_impl function
+    ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+    alloc->buffer_owned = true;
     alloc->measure = true;
+    ggml_tallocr_reset(alloc);
+    return alloc;
+}
 
+ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
+    ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+    alloc->buffer_owned = true;
     return alloc;
 }
 
-struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
-    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
+ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+    ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
-    *alloc = (struct ggml_allocr){
+    *alloc = (struct ggml_tallocr) {
         /*.buffer        = */ buffer,
         /*.buffer_owned  = */ false,
         /*.base          = */ ggml_backend_buffer_get_base(buffer),
         /*.alignment     = */ ggml_backend_buffer_get_alignment(buffer),
         /*.n_free_blocks = */ 0,
         /*.free_blocks   = */ {{0}},
-        /*.hash_table    = */ {{0}},
         /*.max_size      = */ 0,
         /*.measure       = */ false,
-        /*.parse_seq     = */ {0},
-        /*.parse_seq_len = */ 0,
 #ifdef GGML_ALLOCATOR_DEBUG
         /*.allocated_tensors = */ {0},
 #endif
     };
 
-    ggml_allocr_reset(alloc);
+    ggml_tallocr_reset(alloc);
 
     return alloc;
 }
 
-void ggml_allocr_free(struct ggml_allocr * alloc) {
+struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
+    return alloc->buffer;
+}
+
+void ggml_tallocr_free(ggml_tallocr_t alloc) {
+    if (alloc == NULL) {
+        return;
+    }
+
     if (alloc->buffer_owned) {
         ggml_backend_buffer_free(alloc->buffer);
     }
     free(alloc);
 }
 
-bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
+bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
     return alloc->measure;
 }
 
-//////////// compute graph allocator
+size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
+    return alloc->max_size;
+}
+
+// graph allocator
+
+struct hash_node {
+    int n_children;
+    int n_views;
+};
+
+struct ggml_gallocr {
+    ggml_tallocr_t talloc;
+    struct ggml_hash_set hash_set;
+    struct hash_node * hash_values;
+    size_t hash_values_size;
+    ggml_tallocr_t * hash_allocs;
+    int * parse_seq;
+    int parse_seq_len;
+};
+
+ggml_gallocr_t ggml_gallocr_new(void) {
+    ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
+
+    *galloc = (struct ggml_gallocr) {
+        /*.talloc           = */ NULL,
+        /*.hash_set         = */ {0},
+        /*.hash_values      = */ NULL,
+        /*.hash_values_size = */ 0,
+        /*.hash_allocs      = */ NULL,
+        /*.parse_seq        = */ NULL,
+        /*.parse_seq_len    = */ 0,
+    };
+
+    return galloc;
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+    if (galloc == NULL) {
+        return;
+    }
+
+    if (galloc->hash_set.keys != NULL) {
+        free(galloc->hash_set.keys);
+    }
+    if (galloc->hash_values != NULL) {
+        free(galloc->hash_values);
+    }
+    if (galloc->hash_allocs != NULL) {
+        free(galloc->hash_allocs);
+    }
+    if (galloc->parse_seq != NULL) {
+        free(galloc->parse_seq);
+    }
+    free(galloc);
+}
+
+void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
+    free(galloc->parse_seq);
+    galloc->parse_seq = malloc(sizeof(int) * n);
+
+    for (int i = 0; i < n; i++) {
+        galloc->parse_seq[i] = list[i];
+    }
+    galloc->parse_seq_len = n;
+}
+
+static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
+    return &galloc->hash_values[i];
+}
 
 static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
     if (a->type != b->type) {
@@ -378,23 +438,38 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
     }
 }
 
-static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) {
-    assert(view->view_src != NULL && view->view_src->data != NULL);
+static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    if (galloc->talloc != NULL) {
+        return galloc->talloc;
+    }
+
+    return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
+}
+
+static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, view);
+
+    //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
+    GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
     view->backend = view->view_src->backend;
     view->buffer  = view->view_src->buffer;
     view->data    = (char *)view->view_src->data + view->view_offs;
 
     // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
     // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
-    assert(ggml_allocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
-    ggml_backend_buffer_init_tensor(alloc->buffer, view);
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+
+    if (!alloc->measure) {
+        ggml_backend_buffer_init_tensor(alloc->buffer, view);
+    }
 }
 
-static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
-    struct hash_node * ht = alloc->hash_table;
+static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, node);
+
     if (node->data == NULL) {
         if (ggml_is_view(node)) {
-            init_view(alloc, node);
+            init_view(galloc, node);
         } else {
             // see if we can reuse a parent's buffer (inplace)
             if (ggml_op_can_inplace(node->op)) {
@@ -405,16 +480,16 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                     }
 
                     // if the node's data is external, then we cannot re-use it
-                    if (ggml_allocr_is_own(alloc, parent) == false) {
+                    if (ggml_tallocr_is_own(alloc, parent) == false) {
                         AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
                         continue;
                     }
 
-                    struct hash_node * p_hn = hash_get(ht, parent);
+                    struct hash_node * p_hn = hash_get(galloc, parent);
                     if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
                         if (ggml_is_view(parent)) {
                             struct ggml_tensor * view_src = parent->view_src;
-                            struct hash_node * view_src_hn = hash_get(ht, view_src);
+                            struct hash_node * view_src_hn = hash_get(galloc, view_src);
                             if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
                                 // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
                                 // the parent's data that it will need later (same layout requirement). the problem is that then
@@ -424,7 +499,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                                 AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
                                 node->view_src = view_src;
                                 view_src_hn->n_views += 1;
-                                init_view(alloc, node);
+                                init_view(galloc, node);
                                 return;
                             }
                         }
@@ -432,163 +507,260 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                             AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
                             node->view_src = parent;
                             p_hn->n_views += 1;
-                            init_view(alloc, node);
+                            init_view(galloc, node);
                             return;
                         }
                     }
                 }
             }
-            ggml_allocr_alloc(alloc, node);
+            ggml_tallocr_alloc(alloc, node);
         }
     }
 }
 
-size_t ggml_allocr_alloc_graph_n(
-    struct ggml_allocr * alloc,
-    struct ggml_cgraph ** graphs, int n_graphs,
-    struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
+static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, node);
 
-    // reset hash table
-    struct hash_node * ht = alloc->hash_table;
-    memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
+    ggml_tallocr_free_tensor(alloc, node);
+}
+
+static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
+    const int * parse_seq     = galloc->parse_seq;
+    int         parse_seq_len = galloc->parse_seq_len;
 
     // count number of children and views
-    for (int g = 0; g < n_graphs; g++) {
-        struct ggml_cgraph * gf = graphs[g];
-        for (int i = 0; i < gf->n_nodes; i++) {
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (ggml_is_view(node)) {
+            struct ggml_tensor * view_src = node->view_src;
+            hash_get(galloc, view_src)->n_views += 1;
+            if (node->buffer == NULL && node->data != NULL) {
+                // view of a pre-allocated tensor, didn't call init_view() yet
+                init_view(galloc, node);
+            }
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                break;
+            }
+            hash_get(galloc, parent)->n_children += 1;
+            if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
+                init_view(galloc, parent);
+            }
+        }
+   }
+
+    // allocate tensors
+    // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
+    int last_barrier_pos = 0;
+    int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
+
+    for (int ind = 0; ind < n_nodes; ind++) {
+        // allocate a node if there is no parse_seq or this is not a barrier
+        if (parse_seq_len == 0 || parse_seq[ind] != -1) {
+            int i = parse_seq_len ? parse_seq[ind] : ind;
             struct ggml_tensor * node = gf->nodes[i];
 
-            if (ggml_is_view(node)) {
-                struct ggml_tensor * view_src = node->view_src;
-                hash_get(ht, view_src)->n_views += 1;
-                if (node->buffer == NULL && node->data != NULL) {
-                    // view of a pre-allocated tensor, didn't call init_view() yet
-                    init_view(alloc, node);
+            // allocate parents (leafs)
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
                 }
+                allocate_node(galloc, parent);
             }
 
+            // allocate node
+            allocate_node(galloc, node);
+
+            AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * parent = node->src[j];
                 if (parent == NULL) {
                     break;
                 }
-                hash_get(ht, parent)->n_children += 1;
-                if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
-                    init_view(alloc, parent);
+                AT_PRINTF("%s", parent->name);
+                if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                    AT_PRINTF(", ");
                 }
             }
+            AT_PRINTF("\n");
         }
-    }
 
-    // allocate tensors
-    for (int g = 0; g < n_graphs; g++) {
-        struct ggml_cgraph * gf = graphs[g];
-        AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
-        // graph inputs are allocated first to ensure that they are not overwritten by each other
-        if (inputs != NULL && inputs[g] != NULL) {
-            for (int i = 0; inputs[g][i] != NULL; i++) {
-                struct ggml_tensor * input = inputs[g][i];
-                AT_PRINTF("input: %s\n", input->name);
-                allocate_node(alloc, input);
-            }
-        }
-        // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
-        int last_barrier_pos = 0;
-        int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
+        // update parents
+        // update immediately if there is no parse_seq
+        // update only at barriers if there is parse_seq
+        if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
+            int update_start = parse_seq_len ? last_barrier_pos : ind;
+            int update_end   = parse_seq_len ? ind              : ind + 1;
+            for (int i = update_start; i < update_end; i++) {
+                int node_i = parse_seq_len ? parse_seq[i] : i;
+                struct ggml_tensor * node = gf->nodes[node_i];
 
-        for (int ind = 0; ind < n_nodes; ind++) {
-            // allocate a node if there is no parse_seq or this is not a barrier
-            if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
-                int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
-                struct ggml_tensor * node = gf->nodes[i];
-
-                // allocate parents (leafs)
                 for (int j = 0; j < GGML_MAX_SRC; j++) {
                     struct ggml_tensor * parent = node->src[j];
                     if (parent == NULL) {
                         break;
                     }
-                    allocate_node(alloc, parent);
-                }
+                    struct hash_node * p_hn = hash_get(galloc, parent);
+                    p_hn->n_children -= 1;
 
-                // allocate node
-                allocate_node(alloc, node);
+                    //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
 
-                AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    struct ggml_tensor * parent = node->src[j];
-                    if (parent == NULL) {
-                        break;
-                    }
-                    AT_PRINTF("%s", parent->name);
-                    if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
-                        AT_PRINTF(", ");
-                    }
-                }
-                AT_PRINTF("\n");
-            }
-
-            // update parents
-            // update immediately if there is no parse_seq
-            // update only at barriers if there is parse_seq
-            if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
-                int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
-                int update_end   = alloc->parse_seq_len ? ind              : ind + 1;
-                for (int i = update_start; i < update_end; i++) {
-                    int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
-                    struct ggml_tensor * node = gf->nodes[node_i];
-
-                    for (int j = 0; j < GGML_MAX_SRC; j++) {
-                        struct ggml_tensor * parent = node->src[j];
-                        if (parent == NULL) {
-                            break;
-                        }
-                        struct hash_node * p_hn = hash_get(ht, parent);
-                        p_hn->n_children -= 1;
-
-                        //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
-                        if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                            if (ggml_is_view(parent)) {
-                                struct ggml_tensor * view_src = parent->view_src;
-                                struct hash_node * view_src_hn = hash_get(ht, view_src);
-                                view_src_hn->n_views -= 1;
-                                AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
-                                if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
-                                    ggml_allocr_free_tensor(alloc, view_src);
-                                }
-                            }
-                            else {
-                                if (parent->data != node->data) {
-                                    ggml_allocr_free_tensor(alloc, parent);
-                                }
+                    if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                        if (ggml_is_view(parent)) {
+                            struct ggml_tensor * view_src = parent->view_src;
+                            struct hash_node * view_src_hn = hash_get(galloc, view_src);
+                            view_src_hn->n_views -= 1;
+                            AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                            if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
+                                free_node(galloc, view_src);
                             }
                         }
+                        else {
+                            free_node(galloc, parent);
+                        }
                     }
                 }
-                AT_PRINTF("\n");
-                if (alloc->parse_seq_len) {
-                    last_barrier_pos = ind + 1;
-                }
             }
-        }
-        // free graph outputs here that wouldn't be freed otherwise because they have no children
-        if (outputs != NULL && outputs[g] != NULL) {
-            for (int i = 0; outputs[g][i] != NULL; i++) {
-                struct ggml_tensor * output = outputs[g][i];
-                AT_PRINTF("output: %s\n", output->name);
-                ggml_allocr_free_tensor(alloc, output);
+            AT_PRINTF("\n");
+            if (parse_seq_len) {
+                last_barrier_pos = ind + 1;
             }
         }
     }
+}
 
-    return alloc->max_size;
+size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
+    size_t hash_size = graph->visited_hash_table.size;
+
+    // check if the hash table is initialized and large enough
+    if (galloc->hash_set.size < hash_size) {
+        if (galloc->hash_set.keys != NULL) {
+            free(galloc->hash_set.keys);
+        }
+        if (galloc->hash_values != NULL) {
+            free(galloc->hash_values);
+        }
+        galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
+        galloc->hash_set.size = hash_size;
+        galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+    }
+
+    // reset hash table
+    memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
+    memset(galloc->hash_values,   0, sizeof(struct hash_node) * hash_size);
+
+    galloc->talloc = talloc;
+    ggml_tallocr_alloc_graph_impl(galloc, graph);
+    galloc->talloc = NULL;
+
+    size_t max_size = ggml_tallocr_max_size(talloc);
+
+    return max_size;
 }
 
-size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
-    return ggml_allocr_alloc_graph_n(alloc, &graph, 1, NULL, NULL);
+void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_alloct) {
+    const size_t hash_size = hash_set.size;
+
+    GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+
+    galloc->talloc = NULL;
+
+    // alloc hash_values if needed
+    if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
+        free(galloc->hash_values);
+        galloc->hash_values      = malloc(sizeof(struct hash_node) * hash_size);
+        galloc->hash_values_size = hash_size;
+    }
+
+    // free hash_set.keys if needed
+    if (galloc->hash_set.keys != NULL) {
+        free(galloc->hash_set.keys);
+    }
+    galloc->hash_set = hash_set;
+
+    // reset hash values
+    memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+    galloc->hash_allocs = hash_node_alloct;
+
+    ggml_tallocr_alloc_graph_impl(galloc, graph);
+
+    // remove unowned resources
+    galloc->hash_set.keys = NULL;
+    galloc->hash_allocs = NULL;
 }
 
-size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
-    return alloc->max_size;
+// legacy API wrapper
+
+struct ggml_allocr {
+    ggml_tallocr_t talloc;
+    ggml_gallocr_t galloc;
+};
+
+static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
+    ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
+    *alloc = (struct ggml_allocr) {
+        /*.talloc = */ talloc,
+        /*.galloc = */ ggml_gallocr_new(),
+    };
+    return alloc;
+}
+
+ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
+    return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
+}
+
+ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
+}
+
+ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
+}
+
+struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
+    return ggml_tallocr_get_buffer(alloc->talloc);
+}
+
+void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
+    ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
+}
+
+void ggml_allocr_free(ggml_allocr_t alloc) {
+    ggml_gallocr_free(alloc->galloc);
+    ggml_tallocr_free(alloc->talloc);
+    free(alloc);
+}
+
+bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
+    return ggml_tallocr_is_measure(alloc->talloc);
+}
+
+void ggml_allocr_reset(ggml_allocr_t alloc) {
+    ggml_tallocr_reset(alloc->talloc);
+}
+
+void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
+    ggml_tallocr_alloc(alloc->talloc, tensor);
+}
+
+size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
+    return ggml_tallocr_max_size(alloc->talloc);
+}
+
+size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
+    return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
 }
diff --git a/src/ggml-backend-impl.h b/src/ggml-backend-impl.h
new file mode 100644 (file)
index 0000000..211e3d4
--- /dev/null
@@ -0,0 +1,87 @@
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    //
+    // Backend buffer
+    //
+
+    typedef void * ggml_backend_buffer_context_t;
+
+    struct ggml_backend_buffer_i {
+        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
+        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
+        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
+        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
+        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+    };
+
+    struct ggml_backend_buffer {
+        struct ggml_backend_buffer_i iface;
+
+        ggml_backend_t                backend;
+        ggml_backend_buffer_context_t context;
+
+        size_t size;
+    };
+
+    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
+            struct ggml_backend                  * backend,
+            struct ggml_backend_buffer_i           iface,
+                   ggml_backend_buffer_context_t   context,
+                   size_t                          size);
+
+    //
+    // Backend
+    //
+
+    typedef void * ggml_backend_context_t;
+
+    struct ggml_backend_i {
+        const char * (*get_name)(ggml_backend_t backend);
+
+        void (*free)(ggml_backend_t backend);
+
+        // buffer allocation
+        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+
+        // get buffer alignment
+        size_t (*get_alignment)(ggml_backend_t backend);
+
+        // tensor data access
+        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        void (*synchronize)     (ggml_backend_t backend);
+
+        // (optional) copy tensor between different backends, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        // compute graph with a plan
+        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+        // compute graph without a plan
+        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+        // check if the backend supports an operation
+        bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+    };
+
+    struct ggml_backend {
+        struct ggml_backend_i iface;
+
+        ggml_backend_context_t context;
+    };
+
+#ifdef  __cplusplus
+}
+#endif
index ca8d83dafe47c9763b7f648b9d26bd4e6dfb985e..00f78432c3d17f2b92278929865015d6be48dc3e 100644 (file)
@@ -1,7 +1,9 @@
-#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
 #include "ggml-alloc.h"
+#include "ggml-impl.h"
 
 #include <assert.h>
+#include <limits.h>
 #include <stdarg.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -33,6 +35,10 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
 }
 
 void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return;
+    }
+
     if (buffer->iface.free_buffer != NULL) {
         buffer->iface.free_buffer(buffer);
     }
@@ -43,15 +49,20 @@ size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
     return ggml_backend_get_alignment(buffer->backend);
 }
 
-void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_base(buffer);
-}
-
 size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
     return buffer->size;
 }
 
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    void * base = buffer->iface.get_base(buffer);
+
+    GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+    return base;
+}
+
 size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
     if (buffer->iface.get_alloc_size) {
         return buffer->iface.get_alloc_size(buffer, tensor);
     }
@@ -59,12 +70,14 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g
 }
 
 void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // init_tensor is optional
     if (buffer->iface.init_tensor) {
         buffer->iface.init_tensor(buffer, tensor);
     }
 }
 
 void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // free_tensor is optional
     if (buffer->iface.free_tensor) {
         buffer->iface.free_tensor(buffer, tensor);
     }
@@ -73,14 +86,21 @@ void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_t
 // backend
 
 ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
-    return tensor->buffer->backend;
+    return tensor->buffer ? tensor->buffer->backend : NULL;
 }
 
 const char * ggml_backend_name(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return "NULL";
+    }
     return backend->iface.get_name(backend);
 }
 
 void ggml_backend_free(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return;
+    }
+
     backend->iface.free(backend);
 }
 
@@ -101,13 +121,23 @@ void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * dat
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
-    ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
+    ggml_backend_t backend = ggml_get_backend(tensor);
+
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(backend != NULL && "tensor backend not set");
+
+    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+    backend->iface.synchronize(backend);
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
-    ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
+    ggml_backend_t backend = ggml_get_backend(tensor);
+
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(backend != NULL && "tensor backend not set");
+
+    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+    backend->iface.synchronize(backend);
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
@@ -156,7 +186,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
     //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
-    // printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
+    // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
 
     if (src == dst) {
         return;
@@ -234,6 +264,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backen
     size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
     void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
 
+    GGML_ASSERT(data != NULL && "failed to allocate buffer");
+
     return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
 }
 
@@ -271,8 +303,7 @@ static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml
 }
 
 static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    // for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends
-    ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
 
     UNUSED(backend);
 }
@@ -383,3 +414,533 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
 ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
     return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
 }
+
+// scheduler
+
+#define GGML_MAX_BACKENDS 4
+#define GGML_MAX_SPLITS 256
+#define GGML_MAX_SPLIT_INPUTS 16
+
+struct ggml_backend_sched_split {
+    ggml_tallocr_t tallocr;
+    int i_start;
+    int i_end;
+    struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
+    int n_inputs;
+    struct ggml_cgraph * graph;
+};
+
+struct ggml_backend_sched {
+    int n_backends;
+    ggml_backend_t backends[GGML_MAX_BACKENDS];
+    ggml_tallocr_t  tallocs[GGML_MAX_BACKENDS];
+
+    ggml_gallocr_t galloc;
+
+    struct ggml_hash_set    hash_set;
+    ggml_tallocr_t *        node_talloc;                     // [hash_set.size]
+    struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
+
+    struct ggml_cgraph * graph;
+    struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
+    int n_splits;
+
+    struct ggml_context * ctx;
+
+    // align context_buffer to GGML_MEM_ALIGN
+    __attribute__((aligned(GGML_MEM_ALIGN)))
+    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
+};
+
+#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
+#define node_allocr(node) sched->node_talloc[hash_id(node)]
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// returns the priority of the backend, lower is better
+static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->backends[i] == backend) {
+            return i;
+        }
+    }
+    return INT_MAX;
+}
+
+static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return i;
+        }
+    }
+    return INT_MAX;
+}
+
+// returns the backend that should be used for the node based on the current locations
+char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+    // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
+    // ie. kv cache updates
+    // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
+    // dst
+    ggml_backend_t cur_backend = ggml_get_backend(node);
+    if (cur_backend != NULL) {
+        sprintf(causes[hash_id(node)], "1.dst");
+        return cur_backend;
+    }
+
+    // view_src
+    if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
+        sprintf(causes[hash_id(node)], "1.vsrc");
+        return ggml_get_backend(node->view_src);
+    }
+
+    // src
+    int cur_prio = INT_MAX;
+    size_t cur_size = 0;
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        const struct ggml_tensor * src = node->src[i];
+        if (src == NULL) {
+            break;
+        }
+        ggml_backend_t src_backend = ggml_get_backend(src);
+        if (src_backend != NULL) {
+            int src_prio = sched_backend_prio(sched, src_backend);
+            size_t src_size = ggml_nbytes(src);
+            if (src_prio < cur_prio && src_size >= cur_size) {
+                cur_prio = src_prio;
+                cur_size = src_size;
+                cur_backend = src_backend;
+                sprintf(causes[hash_id(node)], "1.src%d", i);
+            }
+        }
+    }
+    return cur_backend;
+}
+
+static char * fmt_size(size_t size) {
+    static char buffer[128];
+    if (size >= 1024*1024) {
+        sprintf(buffer, "%zuM", size/1024/1024);
+    } else {
+        sprintf(buffer, "%zuK", size/1024);
+    }
+    return buffer;
+}
+
+static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+            ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+            for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+            }
+            fprintf(stderr, "\n");
+            cur_split++;
+        }
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
+            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+        }
+        fprintf(stderr, "\n");
+    }
+}
+
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+    struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        dup->nb[i] = tensor->nb[i];
+    }
+    return dup;
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+// TODO: merge passes
+static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    // reset state
+    size_t hash_size = sched->hash_set.size;
+    memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
+    memset(sched->node_talloc,   0, sizeof(sched->node_talloc[0])   * hash_size);
+    memset(sched->node_copies,   0, sizeof(sched->node_copies[0])   * hash_size);
+    sched->n_splits = 0;
+
+    struct ggml_init_params params = {
+        /*.mem_size =   */ sizeof(sched->context_buffer),
+        /*.mem_buffer = */ sched->context_buffer,
+        /*.no_alloc =   */ true
+    };
+
+    if (sched->ctx != NULL) {
+        ggml_free(sched->ctx);
+    }
+
+    sched->ctx = ggml_init(params);
+
+    // pass 1: assign backends to ops with allocated inputs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        if (node_allocr(leaf) != NULL) {
+            // do not overwrite user assignments
+            continue;
+        }
+        ggml_backend_t leaf_backend = ggml_get_backend(leaf);
+        if (leaf_backend == NULL && leaf->view_src != NULL) {
+            leaf_backend = ggml_get_backend(leaf->view_src);
+        }
+        if (leaf_backend != NULL) {
+            node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
+        }
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (node_allocr(node) != NULL) {
+            // do not overwrite user assignments
+            continue;
+        }
+        ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
+        if (node_backend != NULL) {
+            node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
+        }
+    }
+    //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 2: assign backends to ops from current assignments
+    // TODO:
+    //  - reuse sched_backend_from_cur
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        if (node_allocr == NULL) {
+            int    cur_prio = INT_MAX;
+            size_t cur_size = 0;
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    break;
+                }
+                ggml_tallocr_t src_allocr = node_allocr(src);
+                if (src_allocr != NULL) {
+                    int    src_prio = sched_allocr_prio(sched, src_allocr);
+                    size_t src_size = ggml_nbytes(src);
+                    if (src_prio < cur_prio && src_size >= cur_size) {
+                        cur_prio = src_prio;
+                        cur_size = src_size;
+                        node_allocr = src_allocr;
+                        sprintf(causes[hash_id(node)], "2.src%d", j);
+                    }
+                }
+            }
+            if (node_allocr != NULL) {
+                node_allocr(node) = node_allocr;
+            }
+        }
+    }
+    //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 3: assign backends to remaining src from dst (should only be leafs)
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr == NULL) {
+                node_allocr(src) = node_allocr;
+            }
+        }
+    }
+    //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 4: split graph, find tensors that need to be copied
+    // TODO:
+    //  - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
+    // find first backend
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (node->view_src == NULL) {
+            sched->splits[0].tallocr = node_allocr(node);
+            break;
+        }
+    }
+    sched->splits[0].i_start = 0;
+    sched->splits[0].n_inputs = 0;
+    memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
+    ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
+    size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+
+        ggml_tallocr_t node_allocr = node_allocr(node);
+
+        if (node_allocr != cur_allocr) {
+            sched->splits[cur_split].i_end = i;
+            cur_split++;
+            GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
+            sched->splits[cur_split].tallocr = node_allocr;
+            sched->splits[cur_split].i_start = i;
+            sched->splits[cur_split].n_inputs = 0;
+            memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
+            cur_allocr = node_allocr;
+            cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+        }
+
+        // find inputs that are not on the same backend
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr != node_allocr) {
+                int n_inputs = sched->splits[cur_split].n_inputs++;
+                GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+                sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
+
+                // create copies
+                size_t id = hash_id(src);
+                if (sched->node_copies[id][cur_backend_id] == NULL) {
+                    struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                    sched->node_copies[id][cur_backend_id] = tensor_copy;
+                    node_allocr(tensor_copy) = cur_allocr;
+                    ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
+                    ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+                }
+                node->src[j] = sched->node_copies[id][cur_backend_id];
+            }
+        }
+    }
+    sched->splits[cur_split].i_end = graph->n_nodes;
+    sched->n_splits = cur_split + 1;
+
+    //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
+
+#if 1
+    // sanity check: all sources should have the same backend as the node
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        if (node_allocr == NULL) {
+            fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
+                fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
+                    node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
+                    j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+            }
+        }
+    }
+#endif
+
+    // create copies of the graph for each split
+    // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &sched->splits[i];
+        split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+
+        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+        for (int j = 0; j < split->n_inputs; j++) {
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
+            input_cpy->src[0] = input;
+            graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+        }
+
+        for (int j = split->i_start; j < split->i_end; j++) {
+            graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+        }
+    }
+    sched->graph = graph_copy;
+}
+
+static void sched_alloc_splits(ggml_backend_sched_t sched) {
+    ggml_gallocr_alloc_graph_n(
+        sched->galloc,
+        sched->graph,
+        sched->hash_set,
+        sched->node_talloc);
+}
+
+static void sched_compute_splits(ggml_backend_sched_t sched) {
+    uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
+    uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
+
+    struct ggml_backend_sched_split * splits = sched->splits;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &splits[i];
+        ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
+        int split_backend_id = sched_backend_prio(sched, split_backend);
+
+        // copy the input tensors to the split backend
+        uint64_t copy_start_us = ggml_time_us();
+        for (int j = 0; j < split->n_inputs; j++) {
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
+            if (split->inputs[j]->buffer == NULL) {
+                if (split->inputs[j]->view_src == NULL) {
+                    fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+                    exit(1);
+                }
+                struct ggml_tensor * view = split->inputs[j];
+                view->backend = view->view_src->backend;
+                view->buffer  = view->view_src->buffer;
+                view->data    = (char *)view->view_src->data + view->view_offs;
+                ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+            }
+            if (input_cpy->buffer == NULL) {
+                fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
+                exit(1);
+            }
+            GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
+            GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+            ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+        }
+        // ggml_backend_synchronize(split_backend);
+        int64_t copy_end_us = ggml_time_us();
+        copy_us[split_backend_id] += copy_end_us - copy_start_us;
+
+#if 0
+        char split_filename[GGML_MAX_NAME];
+        snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
+        ggml_graph_dump_dot(split->graph, NULL, split_filename);
+#endif
+
+        uint64_t compute_start_us = ggml_time_us();
+        ggml_backend_graph_compute(split_backend, split->graph);
+        // ggml_backend_synchronize(split_backend);
+        uint64_t compute_end_us = ggml_time_us();
+        compute_us[split_backend_id] += compute_end_us - compute_start_us;
+    }
+
+#if 0
+    // per-backend timings
+    fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (copy_us[i] > 0 || compute_us[i] > 0) {
+            fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
+        }
+    }
+#endif
+}
+
+static void sched_reset(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_tallocr_reset(sched->tallocs[i]);
+    }
+}
+
+ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
+    GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
+
+    struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
+    memset(sched, 0, sizeof(struct ggml_backend_sched));
+
+    fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
+
+    sched->n_backends = n_backends;
+    for (int i = 0; i < n_backends; i++) {
+        sched->backends[i] = backends[i];
+    }
+
+    sched->galloc = ggml_gallocr_new();
+
+    // init measure allocs for each backend
+    for (int i = 0; i < n_backends; i++) {
+        sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
+    }
+
+    return sched;
+}
+
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+    if (sched == NULL) {
+        return;
+    }
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_tallocr_free(sched->tallocs[i]);
+    }
+    ggml_gallocr_free(sched->galloc);
+    free(sched->hash_set.keys);
+    free(sched->node_talloc);
+    free(sched->node_copies);
+    free(sched);
+}
+
+void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    // initialize hash tables
+    size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
+    sched->hash_set.size = hash_size;
+    sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
+    sched->node_talloc   = malloc(sizeof(sched->node_talloc[0])   * hash_size);
+    sched->node_copies   = malloc(sizeof(sched->node_copies[0])   * hash_size);
+
+    sched_split_graph(sched, measure_graph);
+    sched_alloc_splits(sched);
+
+    // allocate buffers and reset allocators
+    for (int i = 0; i < sched->n_backends; i++) {
+        size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
+        ggml_tallocr_free(sched->tallocs[i]);
+        sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
+    }
+
+    sched_reset(sched);
+}
+
+void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+
+    sched_split_graph(sched, graph);
+    sched_alloc_splits(sched);
+    sched_compute_splits(sched);
+    sched_reset(sched);
+}
+
+ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    return sched->tallocs[backend_index];
+}
+
+ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
+}
+
+void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+    node_allocr(node) = sched->tallocs[backend_index];
+}
index d1e874b6c778af32d887cc5cf5e8970ceb92310f..92eb547f6114aa68e0c8f7907c3660e54a4bea8a 100644 (file)
@@ -81,6 +81,7 @@
 
 #include "ggml-cuda.h"
 #include "ggml.h"
+#include "ggml-backend-impl.h"
 
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
@@ -7449,11 +7450,11 @@ static size_t g_temp_tensor_extra_index = 0;
 
 static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     if (g_temp_tensor_extras == nullptr) {
-        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
     }
 
     size_t alloc_index = g_temp_tensor_extra_index;
-    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
     ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
     memset(extra, 0, sizeof(*extra));
 
@@ -7764,11 +7765,11 @@ struct ggml_backend_buffer_context_cuda {
 
     ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
         if (temp_tensor_extras == nullptr) {
-            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
         }
 
         size_t alloc_index = temp_tensor_extra_index;
-        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
         ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
         memset(extra, 0, sizeof(*extra));
 
@@ -7851,10 +7852,13 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
 };
 
 static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
-    ggml_cuda_set_device(g_main_device);
-
     ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+
+    size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
+
+    ggml_cuda_set_device(g_main_device);
     CUDA_CHECK(cudaMalloc(&ctx->device, size));
+
     return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
 }
 
@@ -7921,6 +7925,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
+            continue;
         assert(node->backend == GGML_BACKEND_GPU);
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
diff --git a/src/ggml-impl.h b/src/ggml-impl.h
new file mode 100644 (file)
index 0000000..e5a2d9d
--- /dev/null
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#include <stddef.h>
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define GGML_HASHTABLE_FULL ((size_t)-1)
+#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
+
+bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
+size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// return index, asserts if table is full
+size_t ggml_hash_find_or_insert(      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+#ifdef __cplusplus
+}
+#endif
index c1901dca75269dbc86971b5dccccaaae8e2c871c..12f66237474bef71840d36c9b110c6392ed7337e 100644 (file)
@@ -1,5 +1,6 @@
 #import "ggml-metal.h"
 
+#import "ggml-backend-impl.h"
 #import "ggml.h"
 
 #import <Foundation/Foundation.h>
@@ -23,7 +24,7 @@
 
 #define UNUSED(x) (void)(x)
 
-#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+#define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
 
 struct ggml_metal_buffer {
     const char * name;
@@ -737,6 +738,21 @@ void ggml_metal_graph_compute(
                 struct ggml_tensor * src1 = gf->nodes[i]->src[1];
                 struct ggml_tensor * dst  = gf->nodes[i];
 
+                switch (dst->op) {
+                    case GGML_OP_NONE:
+                    case GGML_OP_RESHAPE:
+                    case GGML_OP_VIEW:
+                    case GGML_OP_TRANSPOSE:
+                    case GGML_OP_PERMUTE:
+                        {
+                            // noop
+                            continue;
+                        } break;
+                    default:
+                        {
+                        } break;
+                }
+
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
                 const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -790,14 +806,6 @@ void ggml_metal_graph_compute(
                 //}
 
                 switch (dst->op) {
-                    case GGML_OP_NONE:
-                    case GGML_OP_RESHAPE:
-                    case GGML_OP_VIEW:
-                    case GGML_OP_TRANSPOSE:
-                    case GGML_OP_PERMUTE:
-                        {
-                            // noop
-                        } break;
                     case GGML_OP_CONCAT:
                         {
                             const int64_t nb = ne00;
index 5b1e8c6286d22d2a357b276798d435ca09b36a4d..f01363cefb3b62797215486426708053bd603533 100644 (file)
@@ -1,6 +1,7 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 
 #include "ggml.h"
+#include "ggml-impl.h"
 
 #ifdef GGML_USE_K_QUANTS
 #include "k_quants.h"
@@ -126,6 +127,44 @@ typedef void * thread_ret_t;
 #endif
 #endif
 
+#if defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
+
+#include <sys/wait.h>
+
+void ggml_print_backtrace(void) {
+    /*
+    #include <execinfo.h>
+    #include <dlfcn.h>
+
+    void * trace[100];
+
+    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
+
+    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
+    */
+
+    // backtrack_symbols does not show line numbers, use gdb instead
+    char attach[32];
+    snprintf(attach, sizeof(attach), "attach %d", getpid());
+    int pid = fork();
+    if (pid == 0) {
+        execlp("gdb", "gdb", "--batch",
+            "-ex", "set style enabled on",
+            "-ex", attach,
+            "-ex", "bt -frame-info source-and-location",
+            "-ex", "detach",
+            "-ex", "quit",
+            NULL);
+    } else {
+        waitpid(pid, NULL, 0);
+    }
+}
+#else
+void ggml_print_backtrace(void) {
+    // platform not supported
+}
+#endif
+
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -17232,62 +17271,109 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+static size_t ggml_hash_size(size_t min_sz) {
+    // next primes after powers of two
+    static const size_t primes[] = {
+        2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
+        2053, 4099, 8209, 16411, 32771, 65537, 131101,
+        262147, 524309, 1048583, 2097169, 4194319, 8388617,
+        16777259, 33554467, 67108879, 134217757, 268435459,
+        536870923, 1073741827, 2147483659
+    };
+    static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
+
+    // find the smallest prime that is larger or equal to min_sz
+    size_t l = 0;
+    size_t r = n_primes;
+    while (l < r) {
+        size_t m = (l + r)/2;
+        if (primes[m] < min_sz) {
+            l = m + 1;
+        } else {
+            r = m;
+        }
+    }
+    size_t sz = l < n_primes ? primes[l] : min_sz | 1;
+    return sz;
+}
 
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+static size_t ggml_hash(const void * p) {
+    return (size_t)p;
 }
 
-static size_t hash_find(void * hash_table[], void * p) {
-    size_t h = hash(p);
+size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+    size_t h = ggml_hash(key) % hash_set.size;
 
     // linear probing
     size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+    while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) {
+        i = (i + 1) % hash_set.size;
         if (i == h) {
             // visited all hash table entries -> not found
-            return GGML_GRAPH_HASHTABLE_SIZE;
+            return GGML_HASHTABLE_FULL;
         }
     }
     return i;
 }
 
-static bool hash_insert(void * hash_table[], void * p) {
-    size_t i = hash_find(hash_table, p);
+bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+    size_t i = ggml_hash_find(hash_set, key);
+    return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key;
+}
 
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+    size_t i = ggml_hash_find(hash_set, key);
 
-    if (hash_table[i] == p) {
-        return true;
+    GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+    if (hash_set.keys[i] == key) {
+        return GGML_HASHTABLE_ALREADY_EXISTS;
     }
 
     // insert
-    GGML_ASSERT(hash_table[i] == NULL);
-    hash_table[i] = p;
-    return false;
+    GGML_ASSERT(hash_set.keys[i] == NULL);
+    hash_set.keys[i] = key;
+    return i;
 }
 
-static bool hash_contains(void * hash_table[], void * p) {
-    size_t i = hash_find(hash_table, p);
-    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+    size_t i = ggml_hash_find(hash_set, key);
+
+    GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+    hash_set.keys[i] = key;
+    return i;
+}
+
+static struct ggml_hash_set ggml_hash_set_new(size_t size) {
+    size = ggml_hash_size(size);
+    struct ggml_hash_set result;
+    result.size = size;
+    result.keys = malloc(sizeof(struct ggml_tensor *) * size);
+    memset(result.keys, 0, sizeof(struct ggml_tensor *) * size);
+    return result;
+}
+
+static void ggml_hash_set_free(struct ggml_hash_set hash_set) {
+    free(hash_set.keys);
 }
 
 struct hash_map {
-    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
-    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+    struct ggml_hash_set set;
+    struct ggml_tensor ** vals;
 };
 
-static struct hash_map * new_hash_map(void) {
+static struct hash_map * ggml_new_hash_map(size_t size) {
     struct hash_map * result = malloc(sizeof(struct hash_map));
-    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
-        result->keys[i] = NULL;
-        result->vals[i] = NULL;
-    }
+    result->set = ggml_hash_set_new(size);
+    result->vals = malloc(sizeof(struct ggml_tensor *) * result->set.size);
+    memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size);
     return result;
 }
 
-static void free_hash_map(struct hash_map * map) {
+static void ggml_hash_map_free(struct hash_map * map) {
+    ggml_hash_set_free(map->set);
+    free(map->vals);
     free(map);
 }
 
@@ -17307,7 +17393,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
         return node;
     }
 
-    if (!hash_contains(graph->visited_hash_table, node)) {
+    if (!ggml_hash_contains(graph->visited_hash_table, node)) {
         return node;
     }
 
@@ -17322,17 +17408,17 @@ static struct ggml_tensor * ggml_recompute_graph_node(
         return node;
     }
 
-    size_t i = hash_find(replacements->keys, node);
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-    if (replacements->keys[i] == node) {
-        return (struct ggml_tensor *) replacements->vals[i];
+    size_t i = ggml_hash_find(replacements->set, node);
+    GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full
+    if (replacements->set.keys[i] == node) {
+        return replacements->vals[i];
     }
 
     struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
 
     // insert clone into replacements
-    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
-    replacements->keys[i] = node;
+    GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
+    replacements->set.keys[i] = node;
     replacements->vals[i] = clone;
 
     clone->op       = node->op;
@@ -17369,26 +17455,26 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_cgraph    * gb_tmp,
         struct ggml_tensor  * * checkpoints,
         int                     n_checkpoints) {
-    *gb_tmp = *gf;
+    ggml_graph_cpy(gf, gb_tmp);
     ggml_build_backward_expand(ctx, gf, gb_tmp, true);
 
     if (n_checkpoints <= 0) {
-        *gb = *gb_tmp;
+        ggml_graph_cpy(gb_tmp, gb);
         return;
     }
 
-    struct hash_map * replacements = new_hash_map();
+    struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
 
     // insert checkpoints in replacements
     for (int i = 0; i < n_checkpoints; ++i) {
-        size_t k = hash_find(replacements->keys, checkpoints[i]);
-        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
-        replacements->keys[k] = checkpoints[i];
-        replacements->vals[k] = checkpoints[i];
+        size_t k = ggml_hash_find(replacements->set, checkpoints[i]);
+        GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full
+        GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
+        replacements->set.keys[k] = checkpoints[i];
+        replacements->vals[k]     = checkpoints[i];
     }
 
-    *gb = *gf;
+    ggml_graph_cpy(gf, gb);
     // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
     // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
     // by recomputing them from checkpoints
@@ -17405,21 +17491,21 @@ void ggml_build_backward_gradient_checkpointing(
         ggml_build_forward_expand(gb, node);
     }
 
-    free_hash_map(replacements);
+    ggml_hash_map_free(replacements);
 }
 
 // functions to change gradients considering the case that input a might be initial gradient with zero value
 
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
-    if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+    if (ggml_hash_contains(zero_table, a)) {
         return b;
     } else {
         return ggml_add_impl(ctx, a, b, false);
     }
 }
 
-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
-    if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
+    if (ggml_hash_contains(zero_table, a)) {
         struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
     } else {
@@ -17427,23 +17513,23 @@ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct gg
     }
 }
 
-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
-    if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+    if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
     } else {
         return ggml_add1_impl(ctx, a, b, false);
     }
 }
 
-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
-    if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+    if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
     } else {
         return ggml_sub_impl(ctx, a, b, false);
     }
 }
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
 
@@ -18260,7 +18346,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     // check if already visited
-    if (hash_insert(cgraph->visited_hash_table, node)) {
+    if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) {
         return;
     }
 
@@ -18276,7 +18362,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
 
     if (node->op == GGML_OP_NONE && node->grad == NULL) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
-        GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_leafs < cgraph->size);
 
         if (strlen(node->name) == 0) {
             ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
@@ -18285,22 +18371,24 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         cgraph->leafs[cgraph->n_leafs] = node;
         cgraph->n_leafs++;
     } else {
-        GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_nodes < cgraph->size);
 
         if (strlen(node->name) == 0) {
             ggml_format_name(node, "node_%d", cgraph->n_nodes);
         }
 
         cgraph->nodes[cgraph->n_nodes] = node;
-        cgraph->grads[cgraph->n_nodes] = node->grad;
+        if (cgraph->grads) {
+            cgraph->grads[cgraph->n_nodes] = node->grad;
+        }
         cgraph->n_nodes++;
     }
 }
 
 static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
     if (!expand) {
-        cgraph->n_nodes = 0;
-        cgraph->n_leafs = 0;
+        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
+        ggml_graph_clear(cgraph);
     }
 
     const int n0 = cgraph->n_nodes;
@@ -18321,25 +18409,6 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }
 
-struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
-    struct ggml_cgraph result = {
-        /*.n_nodes      =*/ 0,
-        /*.n_leafs      =*/ 0,
-        /*.nodes        =*/ { NULL },
-        /*.grads        =*/ { NULL },
-        /*.leafs        =*/ { NULL },
-        /*.hash_table   =*/ { NULL },
-        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
-    };
-
-    ggml_build_forward_impl(&result, tensor, false);
-
-    return result;
-}
-
 void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
 
@@ -18356,11 +18425,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     }
 
     // remember original gradients which start with zero values
-    void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
-    memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+    struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
     for (int i = 0; i < gf->n_nodes; i++) {
         if (gf->grads[i]) {
-            hash_insert(zero_table, gf->grads[i]);
+            ggml_hash_insert(zero_table, gf->grads[i]);
         }
     }
 
@@ -18383,26 +18451,54 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
-    free(zero_table);
+    ggml_hash_set_free(zero_table);
 }
 
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
-    struct ggml_cgraph result = *gf;
-    ggml_build_backward_expand(ctx, gf, &result, keep);
-    return result;
+static size_t ggml_graph_nbytes(size_t size, bool grads) {
+    size_t nbytes = sizeof(struct ggml_cgraph);
+    nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes
+    if (grads) {
+        nbytes += size * sizeof(struct ggml_tensor *); // grads
+    }
+    nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set
+    return nbytes;
 }
 
-struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
-    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE);
+size_t ggml_graph_overhead_custom(size_t size, bool grads) {
+    return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
+}
+
+size_t ggml_graph_overhead(void) {
+    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
+    const size_t obj_size = ggml_graph_nbytes(size, grads);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
     struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
 
+    struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1);
+
+    size_t hash_size = ggml_hash_size(size * 2);
+    struct ggml_tensor ** nodes_ptr = data_start;
+    struct ggml_tensor ** leafs_ptr = nodes_ptr + size;
+    struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size;
+    struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL;
+
+    // check that we allocated the correct amount of memory
+    assert(obj_size == (size_t) (
+        (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph));
+
+    memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *));
+
     *cgraph = (struct ggml_cgraph) {
+        /*.size         =*/ size,
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
-        /*.nodes        =*/ { NULL },
-        /*.grads        =*/ { NULL },
-        /*.leafs        =*/ { NULL },
-        /*.hash_table   =*/ { NULL },
+        /*.nodes        =*/ nodes_ptr,
+        /*.grads        =*/ grads_ptr,
+        /*.leafs        =*/ leafs_ptr,
+        /*.hash_table   =*/ { hash_size, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
@@ -18412,14 +18508,85 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
     return cgraph;
 }
 
-struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    struct ggml_cgraph * cgraph = ggml_new_graph(ctx);
-    ggml_build_forward_impl(cgraph, tensor, false);
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
+    const size_t obj_size = sizeof(struct ggml_cgraph);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
+    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+    *cgraph = (struct ggml_cgraph) {
+        /*.size         =*/ 0,
+        /*.n_nodes      =*/ i1 - i0,
+        /*.n_leafs      =*/ 0,
+        /*.nodes        =*/ cgraph0->nodes + i0,
+        /*.grads        =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
+        /*.leafs        =*/ NULL,
+        /*.hash_table   =*/ { 0, NULL },
+        /*.order        =*/ cgraph0->order,
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+    };
+
     return cgraph;
 }
 
-size_t ggml_graph_overhead(void) {
-    return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN);
+void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
+    GGML_ASSERT(dst->size >= src->n_leafs);
+    GGML_ASSERT(dst->size >= src->n_nodes);
+    GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size);
+
+    dst->n_leafs = src->n_leafs;
+    dst->n_nodes = src->n_nodes;
+    dst->order   = src->order;
+
+    for (int i = 0; i < src->n_leafs; ++i) {
+        dst->leafs[i] = src->leafs[i];
+    }
+
+    for (int i = 0; i < src->n_nodes; ++i) {
+        dst->nodes[i] = src->nodes[i];
+    }
+
+    if (src->grads) {
+        GGML_ASSERT(dst->grads != NULL);
+        for (int i = 0; i < src->n_nodes; ++i) {
+            dst->grads[i] = src->grads[i];
+        }
+    }
+
+    for (size_t i = 0; i < src->visited_hash_table.size; ++i) {
+        if (src->visited_hash_table.keys[i]) {
+            ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]);
+        }
+    }
+}
+
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+    ggml_graph_cpy(cgraph, result);
+    return result;
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(cgraph->grads != NULL);
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * grad = cgraph->grads[i];
+
+        if (grad) {
+            ggml_set_zero(grad);
+        }
+    }
+}
+
+void ggml_graph_clear(struct ggml_cgraph * cgraph) {
+    cgraph->n_leafs = 0;
+    cgraph->n_nodes = 0;
+    memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
 }
 
 //
@@ -18572,13 +18739,252 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
     node->perf_time_us += time_us_cur;
 }
 
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+    int n_tasks = 0;
+
+    switch (node->op) {
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_ACC:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_SUB:
+        case GGML_OP_DIV:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT_BACK:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_LEAKY:
+                    {
+                        n_tasks = 1;
+                    } break;
+
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+            }
+            break;
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_MUL:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_CONCAT:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                n_tasks = n_threads;
+
+                // TODO: use different scheduling for different matrix sizes
+                //const int nr0 = ggml_nrows(node->src[0]);
+                //const int nr1 = ggml_nrows(node->src[1]);
+
+                //n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
+
+#if defined(GGML_USE_CUBLAS)
+                if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
+                    n_tasks = 1; // TODO: this actually is doing nothing
+                                 //       the threads are still spinning
+                }
+#elif defined(GGML_USE_CLBLAST)
+                if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
+                    n_tasks = 1; // TODO: this actually is doing nothing
+                                 //       the threads are still spinning
+                }
+#endif
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
+                    n_tasks = 1; // TODO: this actually is doing nothing
+                                 //       the threads are still spinning
+                }
+#endif
+            } break;
+        case GGML_OP_OUT_PROD:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_SCALE:
+        case GGML_OP_SET:
+        case GGML_OP_CONT:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_GET_ROWS:
+        case GGML_OP_GET_ROWS_BACK:
+        case GGML_OP_DIAG:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
+        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_ADD_REL_POS:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_ALIBI:
+            {
+                n_tasks = 1; //TODO
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                n_tasks = 1; //TODO
+            } break;
+        case GGML_OP_CONV_1D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_1D_STAGE_0:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_1D_STAGE_1:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_2D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_2D_STAGE_0:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_2D_STAGE_1:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_POOL_1D:
+        case GGML_OP_POOL_2D:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UPSCALE:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_FLASH_ATTN:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_FLASH_FF:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
+        case GGML_OP_GET_REL_POS:
+        case GGML_OP_MAP_UNARY:
+        case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_CUSTOM1_F32:
+        case GGML_OP_MAP_CUSTOM2_F32:
+        case GGML_OP_MAP_CUSTOM3_F32:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
+                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p->n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
+                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p->n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
+                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p->n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_NONE:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    assert(n_tasks > 0);
+
+    return n_tasks;
+}
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
     const struct ggml_cgraph * cgraph = state->shared->cgraph;
     const struct ggml_cplan  * cplan  = state->shared->cplan;
 
-    const int * n_tasks_arr = cplan->n_tasks;
     const int   n_threads   = state->shared->n_threads;
 
     set_numa_thread_affinity(state->ith, n_threads);
@@ -18603,9 +19009,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
             if (node_n != -1) {
                 /* FINALIZE */
-                struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
+                struct ggml_tensor * node = cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = n_tasks_arr[node_n];
+                    params.nth = ggml_get_n_tasks(node, n_threads);
                     ggml_compute_forward(&params, node);
                 }
                 ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -18616,7 +19022,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
 
                 struct ggml_tensor * node = cgraph->nodes[node_n];
-                const int n_tasks = n_tasks_arr[node_n];
+                const int n_tasks = ggml_get_n_tasks(node, n_threads);
 
                 state->shared->perf_node_start_cycles  = ggml_perf_cycles();
                 state->shared->perf_node_start_time_us = ggml_perf_time_us();
@@ -18674,7 +19080,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         /* COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
-        const int n_tasks = n_tasks_arr[node_n];
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
 
         struct ggml_compute_params params = {
             /*.type  =*/ GGML_TASK_COMPUTE,
@@ -18708,122 +19114,46 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        size_t cur = 0;
+
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_ACC:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
-            case GGML_OP_SUB:
-            case GGML_OP_DIV:
-            case GGML_OP_SQR:
-            case GGML_OP_SQRT:
-            case GGML_OP_LOG:
-            case GGML_OP_SUM:
-            case GGML_OP_SUM_ROWS:
-            case GGML_OP_MEAN:
-            case GGML_OP_ARGMAX:
-            case GGML_OP_REPEAT:
-            case GGML_OP_REPEAT_BACK:
-            {
-                    n_tasks = 1;
-                } break;
-
-            case GGML_OP_UNARY:
-                {
-                    switch (ggml_get_unary_op(node)) {
-                        case GGML_UNARY_OP_ABS:
-                        case GGML_UNARY_OP_SGN:
-                        case GGML_UNARY_OP_NEG:
-                        case GGML_UNARY_OP_STEP:
-                        case GGML_UNARY_OP_TANH:
-                        case GGML_UNARY_OP_ELU:
-                        case GGML_UNARY_OP_RELU:
-                        case GGML_UNARY_OP_LEAKY:
-                            {
-                                n_tasks = 1;
-                            } break;
-
-                        case GGML_UNARY_OP_GELU:
-                        case GGML_UNARY_OP_GELU_QUICK:
-                        case GGML_UNARY_OP_SILU:
-                            {
-                                n_tasks = n_threads;
-                            } break;
-                    }
-                } break;
-            case GGML_OP_SILU_BACK:
-            case GGML_OP_MUL:
-            case GGML_OP_NORM:
-            case GGML_OP_RMS_NORM:
-            case GGML_OP_RMS_NORM_BACK:
-            case GGML_OP_GROUP_NORM:
-                {
-                    n_tasks = n_threads;
-                } break;
-            case GGML_OP_CONCAT:
             case GGML_OP_MUL_MAT:
                 {
-                    n_tasks = n_threads;
-
-                    // TODO: use different scheduling for different matrix sizes
-                    //const int nr0 = ggml_nrows(node->src[0]);
-                    //const int nr1 = ggml_nrows(node->src[1]);
-
-                    //n_tasks = MIN(n_threads, MAX(1, nr0/128));
-                    //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
-
-                    size_t cur = 0;
                     const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
 
-#if defined(GGML_USE_CUBLAS)
-                    if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
-                        n_tasks = 1; // TODO: this actually is doing nothing
-                                     //       the threads are still spinning
-                    } else
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
                     if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
-                        n_tasks = 1; // TODO: this actually is doing nothing
-                                     //       the threads are still spinning
                         cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
                     } else
 #endif
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
-                        n_tasks = 1; // TODO: this actually is doing nothing
-                                     //       the threads are still spinning
                         if (node->src[0]->type != GGML_TYPE_F32) {
                             // here we need memory just for single 2D matrix from src0
                             cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
@@ -18832,62 +19162,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 #endif
                     if (node->src[1]->type != vec_dot_type) {
                         cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
-                    } else {
-                        cur = 0;
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_OUT_PROD:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_SCALE:
-                {
-                    n_tasks = 1;
-                } break;
-            case GGML_OP_SET:
-            case GGML_OP_CONT:
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_PERMUTE:
-            case GGML_OP_TRANSPOSE:
-            case GGML_OP_GET_ROWS:
-            case GGML_OP_GET_ROWS_BACK:
-            case GGML_OP_DIAG:
-                {
-                    n_tasks = 1;
-                } break;
-            case GGML_OP_DIAG_MASK_ZERO:
-            case GGML_OP_DIAG_MASK_INF:
-            case GGML_OP_SOFT_MAX:
-            case GGML_OP_SOFT_MAX_BACK:
-            case GGML_OP_ROPE:
-            case GGML_OP_ROPE_BACK:
-            case GGML_OP_ADD_REL_POS:
-                {
-                    n_tasks = n_threads;
-                } break;
-            case GGML_OP_ALIBI:
-                {
-                    n_tasks = 1; //TODO
-                } break;
-            case GGML_OP_CLAMP:
-                {
-                    n_tasks = 1; //TODO
                 } break;
             case GGML_OP_CONV_1D:
                 {
-                    n_tasks = n_threads;
-
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
                     GGML_ASSERT(node->src[1]->ne[2] == 1);
                     GGML_ASSERT(node->src[1]->ne[3] == 1);
@@ -18908,8 +19194,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     UNUSED(ne10);
                     UNUSED(ne11);
 
-                    size_t cur = 0;
-
                     if (node->src[0]->type == GGML_TYPE_F16 &&
                         node->src[1]->type == GGML_TYPE_F32) {
                         cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
@@ -18919,21 +19203,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     } else {
                         GGML_ASSERT(false);
                     }
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_CONV_1D_STAGE_0:
-                {
-                    n_tasks = n_threads;
-                } break;
-            case GGML_OP_CONV_1D_STAGE_1:
-                {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
-                    n_tasks = n_threads;
-
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
                     GGML_ASSERT(node->src[1]->ne[2] == 1);
                     GGML_ASSERT(node->src[1]->ne[3] == 1);
@@ -18945,7 +19217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     const int64_t ne10 = node->src[1]->ne[0];  // L
                     const int64_t ne11 = node->src[1]->ne[1];  // Cin
 
-                    size_t cur = 0;
                     if (node->src[0]->type == GGML_TYPE_F16 &&
                         node->src[1]->type == GGML_TYPE_F32) {
                         cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
@@ -18957,13 +19228,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     } else {
                         GGML_ASSERT(false);
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_CONV_2D:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne00 = node->src[0]->ne[0]; // W
                     const int64_t ne01 = node->src[0]->ne[1]; // H
                     const int64_t ne02 = node->src[0]->ne[2]; // C
@@ -18983,8 +19250,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     UNUSED(ne03);
                     UNUSED(ne2);
 
-                    size_t cur = 0;
-
                     if (node->src[0]->type == GGML_TYPE_F16 &&
                         node->src[1]->type == GGML_TYPE_F32) {
                         // im2col: [N*OH*OW, IC*KH*KW]
@@ -18995,21 +19260,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     } else {
                         GGML_ASSERT(false);
                     }
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_CONV_2D_STAGE_0:
-                {
-                    n_tasks = n_threads;
-                } break;
-            case GGML_OP_CONV_2D_STAGE_1:
-                {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne00 = node->src[0]->ne[0]; // W
                     const int64_t ne01 = node->src[0]->ne[1]; // H
                     const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
@@ -19019,141 +19272,66 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     const int64_t ne11 = node->src[1]->ne[1]; // H
                     const int64_t ne12 = node->src[1]->ne[2]; // Channels In
 
-                    size_t cur = 0;
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_POOL_1D:
-            case GGML_OP_POOL_2D:
-                {
-                    n_tasks = 1;
-                } break;
-            case GGML_OP_UPSCALE:
-                {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-
-                    if (node->src[1]->type == GGML_TYPE_F16) {
+                    } else if (node->src[1]->type == GGML_TYPE_F16) {
                         cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_FLASH_FF:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
-
-                    if (node->src[1]->type == GGML_TYPE_F16) {
+                    } else if (node->src[1]->type == GGML_TYPE_F16) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
                     }
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = 0;
-
                     const int64_t    D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
-                    }
-
-                    if (node->src[1]->type == GGML_TYPE_F16) {
+                    } else if (node->src[1]->type == GGML_TYPE_F16) {
                         cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
                     }
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_WIN_PART:
-            case GGML_OP_WIN_UNPART:
-            case GGML_OP_GET_REL_POS:
-            case GGML_OP_MAP_UNARY:
-            case GGML_OP_MAP_BINARY:
-            case GGML_OP_MAP_CUSTOM1_F32:
-            case GGML_OP_MAP_CUSTOM2_F32:
-            case GGML_OP_MAP_CUSTOM3_F32:
-                {
-                    n_tasks = 1;
-                } break;
-            case GGML_OP_MAP_CUSTOM1:
-                {
-                    struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
-                    if (p->n_tasks == GGML_N_TASKS_MAX) {
-                        n_tasks = n_threads;
-                    } else {
-                        n_tasks = MIN(p->n_tasks, n_threads);
-                    }
-                } break;
-            case GGML_OP_MAP_CUSTOM2:
-                {
-                    struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
-                    if (p->n_tasks == GGML_N_TASKS_MAX) {
-                        n_tasks = n_threads;
-                    } else {
-                        n_tasks = MIN(p->n_tasks, n_threads);
-                    }
-                } break;
-            case GGML_OP_MAP_CUSTOM3:
-                {
-                    struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
-                    if (p->n_tasks == GGML_N_TASKS_MAX) {
-                        n_tasks = n_threads;
-                    } else {
-                        n_tasks = MIN(p->n_tasks, n_threads);
-                    }
                 } break;
+
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
                     n_tasks = n_threads;
 
-                    size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
-
-                    work_size = MAX(work_size, cur);
-                } break;
-            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-                {
-                    n_tasks = n_threads;
-                } break;
-            case GGML_OP_NONE:
-                {
-                    n_tasks = 1;
+                    cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
                 {
                     GGML_ASSERT(false);
                 } break;
+            default:
+                break;
         }
 
-        cplan.n_tasks[i] = n_tasks;
+        work_size = MAX(work_size, cur);
     }
 
     if (work_size > 0) {
@@ -19175,12 +19353,6 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
         if (cplan->work_size > 0) {
             GGML_ASSERT(cplan->work_data);
         }
-
-        for (int i = 0; i < cgraph->n_nodes; ++i) {
-            if (cgraph->nodes[i]->op != GGML_OP_NONE) {
-                GGML_ASSERT(cplan->n_tasks[i] > 0);
-            }
-        }
     }
 
     const int n_threads = cplan->n_threads;
@@ -19253,16 +19425,6 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     return compute_status;
 }
 
-void ggml_graph_reset(struct ggml_cgraph * cgraph) {
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * grad = cgraph->grads[i];
-
-        if (grad) {
-            ggml_set_zero(grad);
-        }
-    }
-}
-
 void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
@@ -19389,12 +19551,12 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
             const uint32_t magic   = GGML_FILE_MAGIC;
             const uint32_t version = GGML_FILE_VERSION;
             const uint32_t n_leafs = cgraph->n_leafs;
-            const uint32_t nodes   = cgraph->n_nodes;
+            const uint32_t n_nodes = cgraph->n_nodes;
 
             fwrite(&magic,     sizeof(uint32_t), 1, fout);
             fwrite(&version,   sizeof(uint32_t), 1, fout);
             fwrite(&n_leafs,   sizeof(uint32_t), 1, fout);
-            fwrite(&nodes,     sizeof(uint32_t), 1, fout);
+            fwrite(&n_nodes,   sizeof(uint32_t), 1, fout);
             fwrite(&size_eval, sizeof(uint64_t), 1, fout);
         }
 
@@ -19482,7 +19644,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                             if (idx == -1) {
                                 for (int k = 0; k < cgraph->n_nodes; ++k) {
                                     if (args[j] == cgraph->nodes[k]) {
-                                        idx = GGML_MAX_NODES + k;
+                                        idx = cgraph->n_leafs + k;
                                         break;
                                     }
                                 }
@@ -19509,11 +19671,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
     }
 }
 
-struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
+struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
     assert(*ctx_data == NULL);
     assert(*ctx_eval == NULL);
 
-    struct ggml_cgraph result = { 0 };
+    struct ggml_cgraph * result = NULL;
 
     struct ggml_tensor * data = NULL;
 
@@ -19585,13 +19747,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
         const uint32_t n_leafs   = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
         const uint32_t n_nodes   = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
         const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
-
-        result.n_leafs = n_leafs;
-        result.n_nodes = n_nodes;
+        const int     graph_size = MAX(n_leafs, n_nodes);
 
         // create the data context
         {
-            const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead();
+            const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false);
 
             struct ggml_init_params params = {
                 .mem_size   = size_eval + overhead,
@@ -19607,6 +19767,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
             }
         }
 
+        result = ggml_new_graph_custom(*ctx_eval, graph_size, false);
+
+        result->n_leafs = n_leafs;
+        result->n_nodes = n_nodes;
+
+
         // leafs
         {
             uint32_t type;
@@ -19645,7 +19811,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     tensor->nb[j] = nb[j];
                 }
 
-                result.leafs[i] = tensor;
+                result->leafs[i] = tensor;
 
                 ptr += ggml_nbytes(tensor);
 
@@ -19697,10 +19863,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         continue;
                     }
 
-                    if (arg_idx < GGML_MAX_NODES) {
-                        args[j] = result.leafs[arg_idx];
+                    if (arg_idx < result->n_leafs) {
+                        args[j] = result->leafs[arg_idx];
                     } else {
-                        args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                        args[j] = result->nodes[arg_idx - result->n_leafs];
                     }
                 }
 
@@ -19752,7 +19918,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     tensor->src[j] = args[j];
                 }
 
-                result.nodes[i] = tensor;
+                result->nodes[i] = tensor;
 
                 fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor));
             }
@@ -20657,10 +20823,11 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
         case GGML_OPT_ADAM:
             {
                 result = (struct ggml_opt_params) {
-                    .type      = GGML_OPT_ADAM,
-                    .n_threads = 1,
-                    .past      = 0,
-                    .delta     = 1e-5f,
+                    .type       = GGML_OPT_ADAM,
+                    .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+                    .n_threads  = 1, // FIXME: GGML_DEFAULT_N_THREADS ?
+                    .past       = 0,
+                    .delta      = 1e-5f,
 
                     .max_no_improvement = 100,
 
@@ -20687,10 +20854,11 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
         case GGML_OPT_LBFGS:
             {
                 result = (struct ggml_opt_params) {
-                    .type      = GGML_OPT_LBFGS,
-                    .n_threads = 1,
-                    .past      = 0,
-                    .delta     = 1e-5f,
+                    .type       = GGML_OPT_LBFGS,
+                    .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+                    .n_threads  = 1,
+                    .past       = 0,
+                    .delta      = 1e-5f,
 
                     .max_no_improvement = 0,
 
@@ -20832,14 +21000,11 @@ enum ggml_opt_result ggml_opt_resume(
         struct ggml_tensor * f) {
 
     // build forward + backward compute graphs
-    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-
-    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
-    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true);
+    ggml_build_forward_expand(gf, f);
 
-    *gf = ggml_build_forward (f);
-    *gb = ggml_build_backward(ctx, gf, true);
+    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
+    ggml_build_backward_expand(ctx, gf, gb, true);
 
     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
index cd0cc5ffd2df45bfff578c76ecab2fd6e4f30d26..2050ad2f3fc453937b684fa17c9fc69ad49a5c42 100644 (file)
@@ -132,15 +132,17 @@ int main(int argc, const char ** argv) {
     {
         dst2 = ggml_mul_mat(ctx0, s0_f32, s1_f32);
 
-        struct ggml_cgraph gf = ggml_build_forward(dst2);
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+        ggml_build_forward_expand(gf, dst2);
+        ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
     }
 
     {
         dst3 = ggml_mul_mat(ctx0, s0_f16, s1_f32);
 
-        struct ggml_cgraph gf = ggml_build_forward(dst3);
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+        ggml_build_forward_expand(gf, dst3);
+        ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
     }
 
     bool ok_blas = true;
index fbd7c89249d9b2d6d328f7f67e6de9ddcd7e117e..116266e9fd547f57bc7f03cb7281830ccdfbfbb6 100644 (file)
@@ -52,6 +52,9 @@ void check_tensor(struct ggml_tensor * t, float * expected_t_d, int ne0, int ne1
             for (int i0 = 0; i0 < ne0; ++i0) {
                 float expected = *(expected_t_d + i2 * ne1 * ne0 + i1 * ne0 + i0);
                 float actual = ggml_get_data_f32(t)[i2 * ne1 * ne0 + i1 * ne0 + i0];
+                if (expected != actual) {
+                    printf("expected %.1f, got %.1f\n", expected, actual);
+                }
                 GGML_ASSERT(expected == actual);
             }
         }
@@ -100,13 +103,17 @@ void test_conv_transpose_1d(void) {
         struct ggml_tensor * out_2 = ggml_conv_transpose_1d(ctx, k, t, 2 /* s0 */, 0 /* p0 */, 1 /* d0 */);
         struct ggml_tensor * out_3 = ggml_conv_transpose_1d(ctx, k, t, 3 /* s0 */, 0 /* p0 */, 1 /* d0 */);
 
-        struct ggml_cgraph gf_1 = ggml_build_forward(out_1);
-        struct ggml_cgraph gf_2 = ggml_build_forward(out_2);
-        struct ggml_cgraph gf_3 = ggml_build_forward(out_3);
+        struct ggml_cgraph * gf_1 = ggml_new_graph(ctx);
+        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);
+        struct ggml_cgraph * gf_3 = ggml_new_graph(ctx);
+
+        ggml_build_forward_expand(gf_1, out_1);
+        ggml_build_forward_expand(gf_2, out_2);
+        ggml_build_forward_expand(gf_3, out_3);
 
-        ggml_graph_compute_with_ctx(ctx, &gf_1, 1);
-        ggml_graph_compute_with_ctx(ctx, &gf_2, 1);
-        ggml_graph_compute_with_ctx(ctx, &gf_3, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_1, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_2, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_3, 1);
 
         check_tensor(out_1, (float*)expected_out_1, 4, 3, 1);
         check_tensor(out_2, (float*)expected_out_2, 6, 3, 1);
@@ -203,13 +210,17 @@ void test_conv_transpose_2d(void) {
         struct ggml_tensor * out_2 = ggml_conv_transpose_2d_p0(ctx, k, t, 2);
         struct ggml_tensor * out_3 = ggml_conv_transpose_2d_p0(ctx, k, t, 3);
 
-        struct ggml_cgraph gf_1 = ggml_build_forward(out_1);
-        struct ggml_cgraph gf_2 = ggml_build_forward(out_2);
-        struct ggml_cgraph gf_3 = ggml_build_forward(out_3);
+        struct ggml_cgraph * gf_1 = ggml_new_graph(ctx);
+        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);
+        struct ggml_cgraph * gf_3 = ggml_new_graph(ctx);
+
+        ggml_build_forward_expand(gf_1, out_1);
+        ggml_build_forward_expand(gf_2, out_2);
+        ggml_build_forward_expand(gf_3, out_3);
 
-        ggml_graph_compute_with_ctx(ctx, &gf_1, 1);
-        ggml_graph_compute_with_ctx(ctx, &gf_2, 1);
-        ggml_graph_compute_with_ctx(ctx, &gf_3, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_1, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_2, 1);
+        ggml_graph_compute_with_ctx(ctx, gf_3, 1);
 
         // printf("in\n");
         // printf_tensor(t);
index ec261ec83584d67838730aa2db1c2b671a3fb220..e96aa67c9bf6629336369998ead380fd7d05cd93 100644 (file)
@@ -150,9 +150,10 @@ int main(int argc, const char** argv) {
 
         struct ggml_tensor * m1 = ggml_map_custom1(ctx, t, custom1, 2, NULL);
 
-        struct ggml_cgraph graph = ggml_build_forward(m1);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, m1);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(m1);
 
@@ -175,9 +176,10 @@ int main(int argc, const char** argv) {
 
         struct ggml_tensor * m2 = ggml_map_custom2(ctx, t1, t2, custom2, GGML_N_TASKS_MAX, g_userdata);
 
-        struct ggml_cgraph graph = ggml_build_forward(m2);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, m2);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(m2);
 
@@ -203,9 +205,10 @@ int main(int argc, const char** argv) {
 
         struct ggml_tensor * m3 = ggml_map_custom3(ctx, t1, t2, t3, custom3, 1, g_userdata);
 
-        struct ggml_cgraph graph = ggml_build_forward(m3);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, m3);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(m3);
 
index 0a559b27ab370e7af3dac9d111bc60a48b812782..7fe9154ddbb1663106a4845dcf8d515070873e13 100644 (file)
@@ -231,9 +231,10 @@ static bool check_gradient(
         printf("GGML_N_THREADS = %d\n", n_threads);
     }
 
-    struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
-    struct ggml_cgraph * gb = ggml_new_graph(ctx0);
-    *gb = *gf;
+    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+    ggml_build_forward_expand(gf, f);
+    ggml_graph_cpy(gf, gb);
     ggml_build_backward_expand(ctx0, gf, gb, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
index 6212da41a6bb144c10a5f0ce66a9556c52de889b..ee52b7a6cd3b7799f15d11f5286b82f60ba2d470 100644 (file)
@@ -97,16 +97,18 @@ bool check_gradient(
         float max_error_rel) {
     const int n_threads = 1;
 
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+    ggml_build_forward_expand(gf, f);
+    struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+    ggml_build_backward_expand(ctx0, gf, gb, false);
 
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-    ggml_graph_reset  (&gf);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
+    ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+    ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
-    ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
-    ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+    ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
+    ggml_graph_dump_dot(gb, gf,   "test-grad0-backward.dot");
 
     for (int i = 0; i < nargs; ++i) {
         const int64_t nelements = ggml_nelements(x[i]);
@@ -115,12 +117,12 @@ bool check_gradient(
             const float x0 = get_element(x[i], k);
 
             set_element(x[i], k, x0 + eps);
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
             set_element(x[i], k, x0 - eps);
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const float f1 = ggml_get_f32_1d(f, 0);
 
@@ -129,9 +131,9 @@ bool check_gradient(
             set_element(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (&gf);
+            ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
             const float g1 = get_element(x[i]->grad, k);
 
@@ -282,8 +284,9 @@ int main(int argc, const char ** argv) {
                 if (ndims <= 2) {
                     check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
                 } else {
-                    struct ggml_cgraph gf = ggml_build_forward(m);
-                    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+                    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+                    ggml_build_forward_expand(gf, m);
+                    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
                 }
 
                 check_mat_mul(m, x[1], x[0]);
@@ -318,8 +321,9 @@ int main(int argc, const char ** argv) {
                 if (ndims <= 2) {
                     check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
                 } else {
-                    struct ggml_cgraph gf = ggml_build_forward(m);
-                    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+                    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+                    ggml_build_forward_expand(gf, m);
+                    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
                 }
 
                 check_mat_mul(m, x[1], x[0]);
index bb8af59620b146fc9c2e93d048cdb6a162c6937f..2c9997fca770503f5d105cb76f1f85dfebb5ea7d 100644 (file)
@@ -109,10 +109,11 @@ int main(void) {
     struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
     struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
 
-    struct ggml_cgraph ge = ggml_build_forward(e);
-    ggml_graph_reset(&ge);
+    struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+    ggml_build_forward_expand(ge, e);
+    ggml_graph_reset(ge);
 
-    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
 
     const float fe = ggml_get_f32_1d(e, 0);
     printf("%s: e = %.4f\n", __func__, fe);
@@ -121,9 +122,9 @@ int main(void) {
 
     ggml_opt(ctx, opt_params, e);
 
-    ggml_graph_reset(&ge);
+    ggml_graph_reset(ge);
 
-    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
 
     const float fe_opt = ggml_get_f32_1d(e, 0);
     printf("%s: original  e = %.4f\n", __func__, fe);
index cdf00f4ec989f129d38ddbdb41ebe6b040da260a..e4626ec63dd963ec8f1d6a532957909fc70fe3fa 100644 (file)
@@ -30,9 +30,10 @@ int main(int argc, const char** argv) {
         GGML_ASSERT(t_pooled->ne[1] == 2);
         GGML_ASSERT(t_pooled->ne[2] == 1);
 
-        struct ggml_cgraph graph = ggml_build_forward(t_pooled);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, t_pooled);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(t_pooled);
 
@@ -57,9 +58,10 @@ int main(int argc, const char** argv) {
         GGML_ASSERT(t_pooled->ne[1] == 2);
         GGML_ASSERT(t_pooled->ne[2] == 1);
 
-        struct ggml_cgraph graph = ggml_build_forward(t_pooled);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, t_pooled);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(t_pooled);
         GGML_ASSERT(output[0] == 3);
@@ -84,9 +86,10 @@ int main(int argc, const char** argv) {
         GGML_ASSERT(t_pooled->ne[2] == 2);
         GGML_ASSERT(t_pooled->ne[3] == 1);
 
-        struct ggml_cgraph graph = ggml_build_forward(t_pooled);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, t_pooled);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(t_pooled);
         GGML_ASSERT(output[0] == 17);
@@ -118,9 +121,10 @@ int main(int argc, const char** argv) {
         GGML_ASSERT(t_pooled->ne[2] == 2);
         GGML_ASSERT(t_pooled->ne[3] == 1);
 
-        struct ggml_cgraph graph = ggml_build_forward(t_pooled);
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, t_pooled);
 
-        ggml_graph_compute_with_ctx(ctx, &graph, 4);
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
 
         const float * output = ggml_get_data_f32(t_pooled);
         GGML_ASSERT(output[0] == 33);
index 19960b453ea950c19141b5df217fb49384a64558..47c8438555a82b021bddb40f55515460c3f7df5b 100644 (file)
@@ -69,12 +69,14 @@ int main(int argc, const char** argv) {
         }
 
         struct ggml_tensor * out = ggml_add_rel_pos(ctx, in, rw_f32, rh_f32);
-        struct ggml_cgraph gf = ggml_build_forward(out);
-        ggml_graph_compute_with_ctx(ctx, &gf, 1);
+        struct ggml_cgraph * gf = ggml_new_graph(ctx);
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_compute_with_ctx(ctx, gf, 1);
 
         out_inplace = ggml_add_rel_pos_inplace(ctx, out_inplace, rw_f32, rh_f32);
-        struct ggml_cgraph gf_2 = ggml_build_forward(out_inplace);
-        ggml_graph_compute_with_ctx(ctx, &gf_2, 1);
+        struct ggml_cgraph * gf_2 = ggml_new_graph(ctx);
+        ggml_build_forward_expand(gf_2, out_inplace);
+        ggml_graph_compute_with_ctx(ctx, gf_2, 1);
 
         check_tensor(out, (float*)expected_out, 9, 4, 1);
         check_tensor(out_inplace, (float*)expected_out, 9, 4, 1);
index 9db47d9bcc83204f2361470818d8a650f70b8e2f..5f331105602887881deaf1929359ec2d5a3f4e4b 100644 (file)
@@ -39,9 +39,10 @@ int main(int argc, char ** argv) {
     struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, KQ_pos, n_embd_head, 512.0f, false);
     struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, KQ_pos, n_embd_head, 512.0f, true);
 
-    struct ggml_cgraph gf = ggml_build_forward(Qx);
-    ggml_build_forward_expand(&gf, Kx);
-    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, Qx);
+    ggml_build_forward_expand(gf, Kx);
+    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
 
        // expected output for Qx:
     // -0.6009  2.7568  1.9782  2.0182
index c313bf8e1b662a960cab28f1f397a90b9972943b..230aaed856d16b697c1cb40152c7a7620036a43a 100644 (file)
@@ -28,16 +28,18 @@ int main(int argc, const char ** argv) {
 
         ggml_print_objects(ctx0);
 
-        struct ggml_cgraph gf = ggml_build_forward(f);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, f);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x, 2.0f);
         ggml_set_f32(a, 3.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(f->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("f     = %f\n", ggml_get_f32_1d(f, 0));
         printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
@@ -47,10 +49,10 @@ int main(int argc, const char ** argv) {
 
         ggml_set_f32(x, 3.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(f->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("f     = %f\n", ggml_get_f32_1d(f, 0));
         printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
@@ -58,8 +60,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(f, 0)       == 27.0f);
         GGML_ASSERT(ggml_get_f32_1d(x->grad, 0) == 18.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-1-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-1-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-1-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-1-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -78,13 +80,15 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * y = ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2));
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -97,21 +101,23 @@ int main(int argc, const char ** argv) {
         struct ggml_tensor * g1 = x1->grad;
         struct ggml_tensor * g2 = x2->grad;
 
-        struct ggml_cgraph gbb = ggml_build_backward(ctx0, &gb, true);
+        struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
 
-        ggml_graph_reset(&gb);
+        ggml_build_backward_expand(ctx0, gb, gbb, true);
+
+        ggml_graph_reset(gb);
         ggml_set_f32(g1->grad, 1.0f);
         ggml_set_f32(g2->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gbb, n_threads);
 
         printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0));
 
         GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 3.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 1.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-2-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-2-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-2-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-2-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -125,16 +131,18 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * y = ggml_mul(ctx0, ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2)), x1);
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 4.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -144,8 +152,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 51.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 9.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-3-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-3-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-3-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-3-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -161,17 +169,19 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * y = ggml_mul(ctx0, ggml_mul(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x2, x2)), x3);
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 1.0f);
         ggml_set_f32(x2, 2.0f);
         ggml_set_f32(x3, 3.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -187,14 +197,16 @@ int main(int argc, const char ** argv) {
         struct ggml_tensor * g2 = x2->grad;
         struct ggml_tensor * g3 = x3->grad;
 
-        struct ggml_cgraph gbb = ggml_build_backward(ctx0, &gb, true);
+        struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
+
+        ggml_build_backward_expand(ctx0, gb, gbb, true);
 
-        ggml_graph_reset(&gb);
+        ggml_graph_reset(gb);
         ggml_set_f32(g1->grad, 1.0f);
         ggml_set_f32(g2->grad, 1.0f);
         ggml_set_f32(g3->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gbb, n_threads);
 
         printf("H * [1, 1, 1] = [ %f %f %f ]\n",
                 ggml_get_f32_1d(x1->grad, 0),
@@ -205,8 +217,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 34.0f);
         GGML_ASSERT(ggml_get_f32_1d(x3->grad, 0) == 12.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-4-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-4-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-4-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-4-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -220,16 +232,18 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * y = ggml_sum(ctx0, ggml_mul(ctx0, x1, x2));
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -249,8 +263,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 5.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 3.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-5-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-5-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-5-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-5-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -273,16 +287,18 @@ int main(int argc, const char ** argv) {
                         )
                     );
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -302,8 +318,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-6-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-6-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-6-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-6-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -326,16 +342,18 @@ int main(int argc, const char ** argv) {
                         )
                     );
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -355,8 +373,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-7-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-7-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-7-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-7-backward.dot");
     }
 
     ///////////////////////////////////////////////////////////////
@@ -373,16 +391,18 @@ int main(int argc, const char ** argv) {
                     ggml_sub(ctx0, x1, x2)
                     );
 
-        struct ggml_cgraph gf = ggml_build_forward(y);
-        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+        ggml_build_forward_expand(gf, y);
+        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
+        ggml_build_backward_expand(ctx0, gf, gb, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -405,10 +425,10 @@ int main(int argc, const char ** argv) {
         ggml_set_f32(x1, 7.0f);
         ggml_set_f32(x2, 5.0f);
 
-        ggml_graph_reset(&gf);
+        ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -428,8 +448,8 @@ int main(int argc, const char ** argv) {
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == -1.0f);
         GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == -1.0f);
 
-        ggml_graph_dump_dot(&gf, NULL, "test1-8-forward.dot");
-        ggml_graph_dump_dot(&gb, &gf,  "test1-8-backward.dot");
+        ggml_graph_dump_dot(gf, NULL, "test1-8-forward.dot");
+        ggml_graph_dump_dot(gb, gf,   "test1-8-backward.dot");
     }
 
     ggml_free(ctx0);