]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : change ggml_graph_compute() API to not require context (#1999)
authorQingyou Meng <redacted>
Fri, 7 Jul 2023 16:24:01 +0000 (00:24 +0800)
committerGitHub <redacted>
Fri, 7 Jul 2023 16:24:01 +0000 (19:24 +0300)
* ggml_graph_compute: deprecate using ggml_context, try resolve issue #287

* rewrite: no longer consider backward compitability; plan and make_plan

* minor: rename ctx as plan; const

* remove ggml_graph_compute from tests/test-grad0.c, but current change breaks backward

* add static ggml_graph_compute_sugar()

* minor: update comments

* reusable buffers

* ggml : more consistent naming + metal fixes

* ggml : fix docs

* tests : disable grad / opt + minor naming changes

* ggml : add ggml_graph_compute_with_ctx()

- backwards compatible API
- deduplicates a lot of copy-paste

* ci : enable test-grad0

* examples : factor out plan allocation into a helper function

* llama : factor out plan stuff into a helper function

* ci : fix env

* llama : fix duplicate symbols + refactor example benchmark

* ggml : remove obsolete assert + refactor n_tasks section

* ggml : fix indentation in switch

* llama : avoid unnecessary bool

* ggml : remove comments from source file and match order in header

---------

Co-authored-by: Georgi Gerganov <redacted>
13 files changed:
.github/workflows/build.yml
examples/baby-llama/baby-llama.cpp
examples/benchmark/benchmark-matmult.cpp
examples/metal/metal.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-metal.h
ggml-metal.m
ggml.c
ggml.h
llama.cpp
tests/CMakeLists.txt
tests/test-grad0.c
tests/test-opt.c

index 12481e8be7cf7ec6adff597717a5b2e755a09d2d..a576139efd0ee6c49d49b0225961a4e33b937a1d 100644 (file)
@@ -16,7 +16,9 @@ on:
     paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu']
 
 env:
- BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+  BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+  GGML_NLOOP: 3
+  GGML_NITER: 1
 
 jobs:
   ubuntu-focal-make:
@@ -64,7 +66,7 @@ jobs:
         id: cmake_test
         run: |
           cd build
-          ctest --verbose
+          ctest --verbose --timeout 900
 
   ubuntu-latest-cmake-sanitizer:
     runs-on: ubuntu-latest
@@ -99,7 +101,7 @@ jobs:
         id: cmake_test
         run: |
           cd build
-          ctest --verbose
+          ctest --verbose --timeout 900
 
   macOS-latest-make:
     runs-on: macos-latest
@@ -147,10 +149,11 @@ jobs:
         id: cmake_test
         run: |
           cd build
-          ctest --verbose
+          ctest --verbose --timeout 900
 
   windows-latest-cmake:
     runs-on: windows-latest
+
     env:
       OPENBLAS_VERSION: 0.3.23
       OPENCL_VERSION: 2023.04.17
@@ -249,7 +252,7 @@ jobs:
         if: ${{ matrix.build != 'clblast' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }} # Test AVX-512 only when possible
         run: |
           cd build
-          ctest -C Release --verbose
+          ctest -C Release --verbose --timeout 900
 
       - name: Get commit hash
         id: commit
index 212f54d32cbad214d05f236e245e8c642987abe2..4965881ecec22a02141ac1fe6bf456c1efcb2d20 100644 (file)
@@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
     return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
 }
 
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 struct ggml_tensor * randomize_tensor(
         struct ggml_tensor * tensor,
         int ndims,
@@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) {
     int n_tokens = model.hparams.n_ctx;
     int n_vocab  = model.hparams.n_vocab;
 
+    std::vector<uint8_t> work_buffer;
+
     for (int ex=0; ex<n_examples; ++ex) {
         struct ggml_init_params params = {
             /*.mem_size   =*/ compute_size,
@@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) {
         int n_past = 0;
 
         ggml_cgraph gf = {};
-        gf.n_threads = 1;
 
         get_example_targets_batch(ctx0, 64*ex+0,  tokens_input, targets);
 
@@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
 
         ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
         float error_before_opt = ggml_get_f32_1d(e, 0);
 
@@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) {
         ggml_opt(ctx0, opt_params_lbfgs, e);
         //
         ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
@@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) {
             struct ggml_context * ctx0 = ggml_init(params);
 
             ggml_cgraph gf = {};
-            gf.n_threads = 1;
 
             int n_past = 0;
             struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
 
             ggml_build_forward_expand(&gf, logits);
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
             struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
             struct ggml_tensor * probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) {
     }
 
     print_matrix(model.tok_embeddings);
-
     printf("done\n");
+
     // ggml_free(kv_self.ctx);
     // ggml_free(model_lora.ctx);
     ggml_free(model.ctx);
+
     return 0;
 }
index 39d15caeb777994f9458c1681753fa74ea0f3d2a..f7215f43bb31ce88c8d16977a8ae30a9a08930aa 100644 (file)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 float tensor_sum_elements(const ggml_tensor * tensor) {
     float sum = 0;
     if (tensor->type==GGML_TYPE_F32) {
@@ -159,13 +170,14 @@ int main(int argc, char ** argv)  {
     // printf("Creating compute graph\n");
     struct ggml_cgraph gf = ggml_build_forward(m11xm2);
 
-    gf.n_threads=benchmark_params.n_threads;
-    printf("cgraph->n_threads=%i\n",gf.n_threads);
+    printf("n_threads=%i\n", benchmark_params.n_threads);
 
     TENSOR_DUMP(m11);
     TENSOR_DUMP(m2);
 
-    ggml_graph_compute(ctx, &gf);
+    std::vector<uint8_t> work_buffer;
+
+    ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
 
     TENSOR_DUMP(gf.nodes[0]);
 
@@ -187,7 +199,6 @@ int main(int argc, char ** argv)  {
 
     // printf("Creating compute graph\n");
     struct ggml_cgraph gf31 = ggml_build_forward(q31);
-    gf31.n_threads=benchmark_params.n_threads;
 
     // Set up a second graph computation to make sure we override the CPU cache lines
     // printf("Creating new tensor q12 & Running quantize\n");
@@ -199,8 +210,7 @@ int main(int argc, char ** argv)  {
 
     //printf("Creating compute graph\n");
     struct ggml_cgraph gf32 = ggml_build_forward(q32);
-    gf32.n_threads=benchmark_params.n_threads;
-    printf("cgraph->n_threads=%i\n",gf31.n_threads);
+    printf("n_threads=%i\n", benchmark_params.n_threads);
 
     const int dimx = sizex;
     const int dimy = sizey;
@@ -221,14 +231,15 @@ int main(int argc, char ** argv)  {
 
         long long int start = ggml_time_us();
         //printf("Running ggml_graph_compute\n");
-        ggml_graph_compute(ctx, &gf31);
+        ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
+
         long long int stop = ggml_time_us();
         long long int usec = stop-start;
         double gflops = (double)(flops_per_matrix)/usec/1000.0;
         gflops_sum += gflops;
         printf("%9i;%8i;%6i;%6i;%6i;%15lli;%18lli;%10.2f\n",
             i,
-            gf31.n_threads,
+            benchmark_params.n_threads,
             sizex, sizey, sizez, flops_per_matrix,
             usec,gflops);
 
@@ -253,7 +264,7 @@ int main(int argc, char ** argv)  {
         }
 
         // Running a different graph computation to make sure we override the CPU cache lines
-        ggml_graph_compute(ctx, &gf32);
+        ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
     }
     printf("\n");
     printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
index cdfe4bfe97865edaaa5a74978454cb3d6060e8f4..7438defdefcdfeada492d9ae758695c914ea9b6b 100644 (file)
@@ -35,10 +35,9 @@ int main(int argc, char ** argv) {
     struct ggml_context * ctx_eval = NULL;
 
     struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
-    gf.n_threads = 1;
 
     // this allocates all Metal resources and memory buffers
-    auto * ctx_metal = ggml_metal_init();
+    auto * ctx_metal = ggml_metal_init(1);
 
     const size_t max_size_data = ggml_get_max_tensor_size(ctx_data);
     const size_t max_size_eval = ggml_get_max_tensor_size(ctx_eval);
index c50eeb343bcef2dad8c93caf6c85307a1c071320..b96fdcdc44b57f3618ecbba8a81454902259d876 100644 (file)
@@ -60,6 +60,17 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
     return rnd->rd(rnd->gen);
 }
 
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
     float scale = 1.0f; // xavier
     switch (tensor->n_dims) {
@@ -1426,11 +1437,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
     gf->n_nodes = 0;
     gf->n_leafs = 0;
-    gf->work_size = 0;
     gf->perf_runs = 0;
     gf->perf_cycles = 0;
     gf->perf_time_us = 0;
-    gf->work = NULL;
 
     const auto & hparams = model->hparams;
     //const int n_ctx      = hparams.n_ctx;
@@ -3162,6 +3171,7 @@ int main(int argc, char ** argv) {
     printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
     // ggml_print_tensor_objects(model.ctx);
 
+    // TODO: use std::vector<uint8_t> intead of "new"
     size_t    compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
     uint8_t * compute_addr = new uint8_t[compute_size];
 
@@ -3183,6 +3193,8 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
     }
 
+    std::vector<uint8_t> work_buffer;
+
     printf("%s: begin training\n", __func__);
 
     for (int ex = 0; ex < params.n_examples; ++ex) {
@@ -3217,9 +3229,6 @@ int main(int argc, char ** argv) {
         struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
         struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
 
-        // ggml_cgraph gf = {};
-        gf->n_threads = params.n_threads;
-        gb->n_threads = params.n_threads;
 
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex,  tokens_input, target_logits, target_probs);
 
@@ -3248,7 +3257,7 @@ int main(int argc, char ** argv) {
             *gb = ggml_build_backward(ctx0, gf, true);
         }
 
-        ggml_graph_compute(ctx0, gf);
+        ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
 
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
@@ -3272,7 +3281,7 @@ int main(int argc, char ** argv) {
         model.train_samples += n_batch;
         model.train_tokens  += n_batch * n_tokens;
 
-        ggml_graph_compute(ctx0, gf);
+        ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
 
         float error_after_opt = ggml_get_f32_1d(loss, 0);
 
@@ -3354,13 +3363,12 @@ int main(int argc, char ** argv) {
             struct ggml_context * ctx0 = ggml_init(cparams);
 
             ggml_cgraph gf = {};
-            gf.n_threads = params.n_threads;
 
             int n_past = 0;
             struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
 
             ggml_build_forward_expand(&gf, logits);
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
 
             //struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
             //struct ggml_tensor * probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -3386,6 +3394,7 @@ int main(int argc, char ** argv) {
     delete[] compute_addr;
     delete[] compute_buf_0;
     delete[] compute_buf_1;
+
     llama_free(lctx);
     llama_free_model(lmodel);
     ggml_free(model.ctx);
index b9e50ac745eb077f927a51fef6ab1af65de1e032..928f1705c381cf710468968b3f5a66e19e5a0c47 100644 (file)
@@ -34,9 +34,13 @@ extern "C" {
 
 struct ggml_metal_context;
 
-struct ggml_metal_context * ggml_metal_init(void);
+// number of command buffers to use
+struct ggml_metal_context * ggml_metal_init(int n_cb);
 void ggml_metal_free(struct ggml_metal_context * ctx);
 
+// set the number of command buffers to use
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
+
 // creates a mapping between a host memory buffer and a device memory buffer
 // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
 // - the mapping is used during computation to determine the arguments of the compute kernels
index fd69c41fe357d6e1636f2640186c7908352450b9..3f15f791f9f65f8490bf50e276766dae71cf07e7 100644 (file)
@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
 };
 
 struct ggml_metal_context {
+    int n_cb;
+
     float * logits;
 
     id<MTLDevice>       device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
 @implementation GGMLMetalClass
 @end
 
-struct ggml_metal_context * ggml_metal_init(void) {
+struct ggml_metal_context * ggml_metal_init(int n_cb) {
     fprintf(stderr, "%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
+    ctx->n_cb   = n_cb;
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     free(ctx);
 }
 
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
+    ctx->n_cb = n_cb;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
-    const int n_cb = gf->n_threads;
+    const int n_cb = ctx->n_cb;
 
     NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
 
diff --git a/ggml.c b/ggml.c
index 4ba7ac9313820a1d5cd0398c2454c3859d3e583c..55b0aff03bf166d6629eb46e9de07ef7d486ddfe 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4583,14 +4583,13 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.src0         =*/ NULL,
         /*.src1         =*/ NULL,
         /*.opt          =*/ { NULL },
-        /*.n_tasks      =*/ 0,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        /*.pad          =*/ { 0 },
+        /*.padding      =*/ { 0 },
     };
 
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -10718,8 +10717,6 @@ static void ggml_compute_forward_mul_mat(
 
         float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-        assert(ne00 % 32 == 0);
-
         for (int64_t ic = 0; ic < ne11; ++ic) {
             vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
@@ -15772,9 +15769,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
-        /*.n_threads    =*/ GGML_DEFAULT_N_THREADS,
-        /*.work_size    =*/ 0,
-        /*.work         =*/ NULL,
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
@@ -15945,12 +15939,13 @@ void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
-    struct ggml_cgraph * cgraph;
+    const struct ggml_cgraph * cgraph;
+    const struct ggml_cplan  * cplan;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
 
-    int n_threads;
+    const int n_threads;
 
     // synchronization primitives
     atomic_int n_active; // num active threads
@@ -15974,9 +15969,13 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
 
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-    struct ggml_cgraph * cgraph = state->shared->cgraph;
 
-    const int n_threads = state->shared->n_threads;
+    const struct ggml_cgraph * cgraph = state->shared->cgraph;
+    const struct ggml_cplan  * cplan  = state->shared->cplan;
+
+    const int * n_tasks_arr = cplan->n_tasks;
+    const int   n_threads   = state->shared->n_threads;
+
     set_numa_thread_affinity(state->ith, n_threads);
 
     int node_n = -1;
@@ -15989,15 +15988,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /*.type  =*/ GGML_TASK_FINALIZE,
                 /*.ith   =*/ 0,
                 /*.nth   =*/ 0,
-                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+                /*.wsize =*/ cplan->work_size,
+                /*.wdata =*/ cplan->work_data,
             };
 
             if (node_n != -1) {
                 /* FINALIZE */
                 struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = node->n_tasks;
+                    params.nth = n_tasks_arr[node_n];
                     ggml_compute_forward(&params, node);
                     ggml_graph_compute_perf_stats_node(node, state->shared);
                 }
@@ -16008,11 +16007,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
 
                 struct ggml_tensor * node = cgraph->nodes[node_n];
+                const int n_tasks = n_tasks_arr[node_n];
 
                 state->shared->perf_node_start_cycles  = ggml_perf_cycles();
                 state->shared->perf_node_start_time_us = ggml_perf_time_us();
 
-                params.nth = node->n_tasks;
+                params.nth = n_tasks;
 
                 /* INIT */
                 if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16020,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                     ggml_compute_forward(&params, node);
                 }
 
-                if (node->n_tasks == 1) {
+                if (n_tasks == 1) {
                     // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
                     // they do something more efficient than spinning (?)
                     params.type = GGML_TASK_COMPUTE;
@@ -16052,16 +16052,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         /* COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
+        const int n_tasks = n_tasks_arr[node_n];
 
         struct ggml_compute_params params = {
             /*.type  =*/ GGML_TASK_COMPUTE,
             /*.ith   =*/ state->ith,
-            /*.nth   =*/ node->n_tasks,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            /*.nth   =*/ n_tasks,
+            /*.wsize =*/ cplan->work_size,
+            /*.wdata =*/ cplan->work_data,
         };
 
-        if (state->ith < node->n_tasks) {
+        if (state->ith < n_tasks) {
             ggml_compute_forward(&params, node);
         }
     }
@@ -16069,349 +16070,372 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     return 0;
 }
 
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    const int n_threads = cgraph->n_threads;
+struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+    if (n_threads <= 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+    }
 
-    struct ggml_compute_state_shared state_shared = {
-        /*.cgraph                  =*/ cgraph,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
-        /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-    };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+    size_t work_size = 0;
 
-    // initialize tasks + work buffer
-    {
-        size_t work_size = 0;
+    struct ggml_cplan cplan;
+    memset(&cplan, 0, sizeof(struct ggml_cplan));
 
-        // thread scheduling for the different operations
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            struct ggml_tensor * node = cgraph->nodes[i];
+    // thread scheduling for the different operations + work buffer size estimation
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        int n_tasks = 1;
 
-            switch (node->op) {
-                case GGML_OP_CPY:
-                case GGML_OP_DUP:
-                    {
-                        node->n_tasks = n_threads;
+        struct ggml_tensor * node = cgraph->nodes[i];
 
-                        size_t cur = 0;
-                        if (ggml_is_quantized(node->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
-                        }
+        switch (node->op) {
+            case GGML_OP_CPY:
+            case GGML_OP_DUP:
+                {
+                    n_tasks = n_threads;
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_ADD:
-                case GGML_OP_ADD1:
-                    {
-                        node->n_tasks = n_threads;
+                    size_t cur = 0;
+                    if (ggml_is_quantized(node->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
+                    }
 
-                        size_t cur = 0;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_ADD:
+            case GGML_OP_ADD1:
+                {
+                    n_tasks = n_threads;
 
-                        if (ggml_is_quantized(node->src0->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
-                        }
+                    size_t cur = 0;
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_ACC:
-                    {
-                        node->n_tasks = n_threads;
+                    if (ggml_is_quantized(node->src0->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
+                    }
 
-                        size_t cur = 0;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_ACC:
+                {
+                    n_tasks = n_threads;
 
-                        if (ggml_is_quantized(node->src0->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
-                        }
+                    size_t cur = 0;
+
+                    if (ggml_is_quantized(node->src0->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
+                    }
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_SUB:
+            case GGML_OP_DIV:
+            case GGML_OP_SQR:
+            case GGML_OP_SQRT:
+            case GGML_OP_LOG:
+            case GGML_OP_SUM:
+            case GGML_OP_SUM_ROWS:
+            case GGML_OP_MEAN:
+            case GGML_OP_ARGMAX:
+            case GGML_OP_REPEAT:
+            case GGML_OP_REPEAT_BACK:
+            case GGML_OP_ABS:
+            case GGML_OP_SGN:
+            case GGML_OP_NEG:
+            case GGML_OP_STEP:
+            case GGML_OP_TANH:
+            case GGML_OP_ELU:
+            case GGML_OP_RELU:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_MUL:
+            case GGML_OP_GELU:
+            case GGML_OP_GELU_QUICK:
+            case GGML_OP_SILU:
+            case GGML_OP_SILU_BACK:
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+            case GGML_OP_RMS_NORM_BACK:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_OUT_PROD:
+                {
+                    n_tasks = n_threads;
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_SUB:
-                case GGML_OP_DIV:
-                case GGML_OP_SQR:
-                case GGML_OP_SQRT:
-                case GGML_OP_LOG:
-                case GGML_OP_SUM:
-                case GGML_OP_SUM_ROWS:
-                case GGML_OP_MEAN:
-                case GGML_OP_ARGMAX:
-                case GGML_OP_REPEAT:
-                case GGML_OP_REPEAT_BACK:
-                case GGML_OP_ABS:
-                case GGML_OP_SGN:
-                case GGML_OP_NEG:
-                case GGML_OP_STEP:
-                case GGML_OP_TANH:
-                case GGML_OP_ELU:
-                case GGML_OP_RELU:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_MUL:
-                case GGML_OP_GELU:
-                case GGML_OP_GELU_QUICK:
-                case GGML_OP_SILU:
-                case GGML_OP_SILU_BACK:
-                case GGML_OP_NORM:
-                case GGML_OP_RMS_NORM:
-                case GGML_OP_RMS_NORM_BACK:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
-                case GGML_OP_MUL_MAT:
-                case GGML_OP_OUT_PROD:
-                    {
-                        node->n_tasks = n_threads;
-
-                        // TODO: use different scheduling for different matrix sizes
-                        //const int nr0 = ggml_nrows(node->src0);
-                        //const int nr1 = ggml_nrows(node->src1);
-
-                        //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
-                        //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
-
-                        size_t cur = 0;
-                        const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
+                    // TODO: use different scheduling for different matrix sizes
+                    //const int nr0 = ggml_nrows(node->src0);
+                    //const int nr1 = ggml_nrows(node->src1);
+
+                    //n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                    //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
+
+                    size_t cur = 0;
+                    const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
 
 #if defined(GGML_USE_CUBLAS)
-                        if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                //       the threads are still spinning
-                        }
-                        else
+                    if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                    } else
 #elif defined(GGML_USE_CLBLAST)
-                        if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                //       the threads are still spinning
-                            cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
-                        }
-                        else
+                    if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                        cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
+                    } else
 #endif
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                               //       the threads are still spinning
-                            if (node->src0->type != GGML_TYPE_F32) {
-                                // here we need memory just for single 2D matrix from src0
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            }
-                        } else
-#endif
-                        if (node->src1->type != vec_dot_type) {
-                            cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
-                        } else {
-                            cur = 0;
+                    if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                        if (node->src0->type != GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                         }
+                    } else
+#endif
+                    if (node->src1->type != vec_dot_type) {
+                        cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
+                    } else {
+                        cur = 0;
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_SCALE:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_SET:
-                case GGML_OP_CONT:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_PERMUTE:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_GET_ROWS:
-                case GGML_OP_GET_ROWS_BACK:
-                case GGML_OP_DIAG:
-                case GGML_OP_DIAG_MASK_ZERO:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                case GGML_OP_SOFT_MAX:
-                case GGML_OP_SOFT_MAX_BACK:
-                case GGML_OP_ROPE:
-                case GGML_OP_ROPE_BACK:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
-                case GGML_OP_ALIBI:
-                    {
-                        node->n_tasks = 1; //TODO
-                    } break;
-                case GGML_OP_CLAMP:
-                    {
-                        node->n_tasks = 1; //TODO
-                    } break;
-                case GGML_OP_CONV_1D:
-                    {
-                        node->n_tasks = n_threads;
-
-                        GGML_ASSERT(node->src0->ne[3] == 1);
-                        GGML_ASSERT(node->src1->ne[2] == 1);
-                        GGML_ASSERT(node->src1->ne[3] == 1);
-
-                        size_t cur = 0;
-                        const int nk = node->src0->ne[0];
-
-                        if (node->src0->type == GGML_TYPE_F16 &&
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_SCALE:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_SET:
+            case GGML_OP_CONT:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+            case GGML_OP_GET_ROWS:
+            case GGML_OP_GET_ROWS_BACK:
+            case GGML_OP_DIAG:
+            case GGML_OP_DIAG_MASK_ZERO:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_DIAG_MASK_INF:
+            case GGML_OP_SOFT_MAX:
+            case GGML_OP_SOFT_MAX_BACK:
+            case GGML_OP_ROPE:
+            case GGML_OP_ROPE_BACK:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_ALIBI:
+                {
+                    n_tasks = 1; //TODO
+                } break;
+            case GGML_OP_CLAMP:
+                {
+                    n_tasks = 1; //TODO
+                } break;
+            case GGML_OP_CONV_1D:
+                {
+                    n_tasks = n_threads;
+
+                    GGML_ASSERT(node->src0->ne[3] == 1);
+                    GGML_ASSERT(node->src1->ne[2] == 1);
+                    GGML_ASSERT(node->src1->ne[3] == 1);
+
+                    size_t cur = 0;
+                    const int nk = node->src0->ne[0];
+
+                    if (node->src0->type == GGML_TYPE_F16 &&
                             node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(ggml_fp16_t)*(
-                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
-                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
-                                    );
-                        } else if (node->src0->type == GGML_TYPE_F32 &&
-                                   node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(float)*(
-                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
-                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
-                                    );
-                        } else {
-                            GGML_ASSERT(false);
-                        }
+                        cur = sizeof(ggml_fp16_t)*(
+                                nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                );
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*(
+                                nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                );
+                    } else {
+                        GGML_ASSERT(false);
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_CONV_2D:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_2D:
+                {
+                    n_tasks = n_threads;
 
-                        GGML_ASSERT(node->src1->ne[3] == 1);
+                    GGML_ASSERT(node->src1->ne[3] == 1);
 
-                        const int64_t ne00 = node->src0->ne[0]; // W
-                        const int64_t ne01 = node->src0->ne[1]; // H
-                        const int64_t ne02 = node->src0->ne[2]; // C
-                        const int64_t ne03 = node->src0->ne[3]; // N
+                    const int64_t ne00 = node->src0->ne[0]; // W
+                    const int64_t ne01 = node->src0->ne[1]; // H
+                    const int64_t ne02 = node->src0->ne[2]; // C
+                    const int64_t ne03 = node->src0->ne[3]; // N
 
-                        const int64_t ne10 = node->src1->ne[0]; // W
-                        const int64_t ne11 = node->src1->ne[1]; // H
-                        const int64_t ne12 = node->src1->ne[2]; // C
+                    const int64_t ne10 = node->src1->ne[0]; // W
+                    const int64_t ne11 = node->src1->ne[1]; // H
+                    const int64_t ne12 = node->src1->ne[2]; // C
 
-                        const int64_t nk = ne00*ne01;
+                    const int64_t nk = ne00*ne01;
 
-                        UNUSED(ne02);
-                        UNUSED(ne03);
-                        UNUSED(nk);
+                    UNUSED(ne02);
+                    UNUSED(ne03);
+                    UNUSED(nk);
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        if (node->src0->type == GGML_TYPE_F16 &&
+                    if (node->src0->type == GGML_TYPE_F16 &&
                             node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
-                        } else if (node->src0->type == GGML_TYPE_F32 &&
-                                   node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(float)*      (ne10*ne11*ne12);
-                        } else {
-                            GGML_ASSERT(false);
-                        }
+                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*      (ne10*ne11*ne12);
+                    } else {
+                        GGML_ASSERT(false);
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_ATTN:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
 
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_FF:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_FF:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_ATTN_BACK:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        const int64_t    D = node->src0->ne[0];
-                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
-                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
-                        }
+                    const int64_t    D = node->src0->ne[0];
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_WIN_PART:
-                case GGML_OP_WIN_UNPART:
-                case GGML_OP_MAP_UNARY:
-                case GGML_OP_MAP_BINARY:
-                case GGML_OP_MAP_CUSTOM1:
-                case GGML_OP_MAP_CUSTOM2:
-                case GGML_OP_MAP_CUSTOM3:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_CROSS_ENTROPY_LOSS:
-                    {
-                        node->n_tasks = n_threads;
-
-                        size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
-
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-                    {
-                        node->n_tasks = n_threads;
-
-                        size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
-
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_NONE:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_COUNT:
-                    {
-                        GGML_ASSERT(false);
-                    } break;
-            }
-        }
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_WIN_PART:
+            case GGML_OP_WIN_UNPART:
+            case GGML_OP_MAP_UNARY:
+            case GGML_OP_MAP_BINARY:
+            case GGML_OP_MAP_CUSTOM1:
+            case GGML_OP_MAP_CUSTOM2:
+            case GGML_OP_MAP_CUSTOM3:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                {
+                    n_tasks = n_threads;
 
-        if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            GGML_ASSERT(false); // TODO: better handling
+                    size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_NONE:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_COUNT:
+                {
+                    GGML_ASSERT(false);
+                } break;
         }
 
-        if (work_size > 0 && cgraph->work == NULL) {
-            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+        cplan.n_tasks[i] = n_tasks;
+    }
+
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads - 1);
+    }
+
+    cplan.n_threads = n_threads;
+    cplan.work_size = work_size;
+    cplan.work_data = NULL;
+
+    return cplan;
+}
+
+void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+    {
+        GGML_ASSERT(cplan);
+        GGML_ASSERT(cplan->n_threads > 0);
+
+        if (cplan->work_size > 0) {
+            GGML_ASSERT(cplan->work_data);
+        }
 
-            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
-            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+        for (int i = 0; i < cgraph->n_nodes; ++i) {
+            if (cgraph->nodes[i]->op != GGML_OP_NONE) {
+                GGML_ASSERT(cplan->n_tasks[i] > 0);
+            }
         }
     }
 
+    const int n_threads = cplan->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.cgraph                  =*/ cgraph,
+        /*.cgraph_plan             =*/ cplan,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
+    };
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
     // create thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; ++j) {
@@ -16473,6 +16497,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     }
 }
 
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+    GGML_ASSERT(buf);
+
+    cplan.work_data = buf->data;
+
+    ggml_graph_compute(cgraph, &cplan);
+}
+
 struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,14 +16546,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
             ne[0], ne[1], ne[2], ne[3],
             nb[0], nb[1], nb[2], nb[3],
-            tensor->n_tasks,
             tensor->data,
             tensor->name);
 }
@@ -17254,9 +17288,6 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_cgraph * gb) {
     GGML_ASSERT(ggml_is_scalar(f));
 
-    gf->n_threads = params.n_threads;
-    gb->n_threads = params.n_threads;
-
     // these will store the parameters we want to optimize
     struct ggml_tensor * ps[GGML_MAX_PARAMS];
 
@@ -17303,7 +17334,8 @@ static enum ggml_opt_result ggml_opt_adam(
     // compute the function value
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute(ctx, gb);
+
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17415,8 @@ static enum ggml_opt_result ggml_opt_adam(
 
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
-        ggml_graph_compute(ctx, gb);
+
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
@@ -17505,7 +17538,8 @@ static enum ggml_opt_result linesearch_backtracking(
 
             ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute(ctx, gb);
+
+            ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
 
             ggml_opt_get_grad(np, ps, g);
 
@@ -17573,9 +17607,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-    gf->n_threads = params.n_threads;
-    gb->n_threads = params.n_threads;
-
     const int m = params.lbfgs.m;
 
     // these will store the parameters we want to optimize
@@ -17627,7 +17658,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
-        ggml_graph_compute(ctx, gb);
+
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         ggml_opt_get_grad(np, ps, g);
 
diff --git a/ggml.h b/ggml.h
index d0710c55591700d51f8cd7d4cd7f810861d19616..ab84bef68747e00a0a6c5e8eecbe89eb11b86bbd 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -65,7 +65,7 @@
 //       ggml_set_f32(a, 3.0f);
 //       ggml_set_f32(b, 4.0f);
 //
-//       ggml_graph_compute(ctx0, &gf);
+//       ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
 //
 //       printf("f = %f\n", ggml_get_f32_1d(f, 0));
 //
@@ -418,9 +418,6 @@ extern "C" {
         struct ggml_tensor * src1;
         struct ggml_tensor * opt[GGML_MAX_OPT];
 
-        // thread scheduling
-        int n_tasks;
-
         // performance
         int     perf_runs;
         int64_t perf_cycles;
@@ -432,19 +429,27 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[4];
+        char padding[8];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+
+        // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
+        int n_tasks[GGML_MAX_NODES];
+    };
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
         int n_leafs;
-        int n_threads;
-
-        size_t work_size;
-        struct ggml_tensor * work;
 
         struct ggml_tensor * nodes[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
@@ -1290,15 +1295,22 @@ extern "C" {
 
     GGML_API void ggml_set_param(
             struct ggml_context * ctx,
-            struct ggml_tensor * tensor);
+            struct ggml_tensor  * tensor);
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
-    GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+    GGML_API              void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
 
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
index 02afdeb14078fe6fbdb92401b554bb3ce2bde753..ee6ec0920fc9c8e3d846ec1c93c7914d87910f6e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -79,6 +79,25 @@ void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
     (void) tensor;
 }
 
+//
+// ggml helpers
+//
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
+//
+// memory sizes
+//
+
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
 {
     static std::map<e_model, size_t> k_sizes = {
@@ -321,6 +340,9 @@ struct llama_context {
     // input embedding (1-dimensional array: [n_embd])
     std::vector<float> embedding;
 
+    // reusable buffer for `struct ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
     // memory buffers used to evaluate the model
     // TODO: move in llama_state
     llama_ctx_buffer buf_compute;
@@ -758,7 +780,6 @@ struct llama_model_loader {
 
 };
 
-
 //
 // kv cache
 //
@@ -1265,7 +1286,7 @@ static bool llama_eval_internal(
            const float * embd,
              const int   n_tokens,
              const int   n_past,
-             const int   n_threads,
+                   int   n_threads,
             const char * cgraph_fname) {
 
     LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
@@ -1306,10 +1327,11 @@ static bool llama_eval_internal(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
+    ggml_cgraph gf = {};
+
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
+    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
@@ -1593,6 +1615,7 @@ static bool llama_eval_internal(
 
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor   (lctx.ctx_metal, cur);
     } else {
@@ -1612,10 +1635,10 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
         }
 
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
     }
 #else
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
 #endif
 
     if (cgraph_fname) {
@@ -2575,8 +2598,8 @@ void llama_free_model(struct llama_model * model) {
 }
 
 struct llama_context * llama_new_context_with_model(
-                             struct llama_model * model,
-            struct llama_context_params   params) {
+                 struct llama_model * model,
+        struct llama_context_params   params) {
 
     if (!model) {
         return nullptr;
@@ -2645,7 +2668,7 @@ struct llama_context * llama_new_context_with_model(
 #ifdef GGML_USE_METAL
     if (params.n_gpu_layers > 0) {
         // this allocates all Metal resources and memory buffers
-        ctx->ctx_metal = ggml_metal_init();
+        ctx->ctx_metal = ggml_metal_init(1);
 
         void * data_ptr  = NULL;
         size_t data_size = 0;
@@ -2802,6 +2825,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
     // read tensors and apply
     bool warned = false;
     int n_tensors = 0;
+
+    std::vector<uint8_t> work_buffer;
+
     while (true) {
         int32_t n_dims;
         int32_t length;
@@ -2966,8 +2992,8 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
             }
 
             struct ggml_cgraph gf = ggml_build_forward(r);
-            gf.n_threads = n_threads;
-            ggml_graph_compute(lora_ctx, &gf);
+
+            ggml_graph_compute_helper(work_buffer, &gf, n_threads);
 
             // we won't need these tensors again, reset the context to save memory
             ggml_free(lora_ctx);
@@ -3120,7 +3146,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
-            gf.n_threads = 1;
 
             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
             kout3d->data = out;
@@ -3140,7 +3165,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
-            ggml_graph_compute(cpy_ctx, &gf);
+            ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
         }
@@ -3226,7 +3251,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
-            gf.n_threads = 1;
 
             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
             kin3d->data = (void *) inp;
@@ -3246,7 +3270,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
-            ggml_graph_compute(cpy_ctx, &gf);
+            ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
         }
index 4171c126c7b7d00bd86fe1a4c9b14ea33a0b5277..1acf050a743e4c38638fa4364412ce019f40a449 100644 (file)
@@ -10,5 +10,5 @@ llama_add_test(test-quantize-fns.cpp)
 llama_add_test(test-quantize-perf.cpp)
 llama_add_test(test-sampling.cpp)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
-llama_add_test(test-grad0.c) # SLOW
+llama_add_test(test-grad0.c) # SLOW
 # llama_add_test(test-opt.c) # SLOW
index a3e25214b84eb29766ff2a879d7d40edcdaf3fc1..da4001ce5269fbac5c9fa720396954f996662d9b 100644 (file)
@@ -10,6 +10,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+
 #define MAX_NARGS 3
 
 #undef MIN
@@ -49,7 +51,7 @@ float frand(void) {
 
 int irand(int n) {
     if (n == 0) return 0;
-    else return rand()%n;
+    return rand()%n;
 }
 
 void get_random_dims(int64_t * dims, int ndims) {
@@ -159,12 +161,14 @@ struct ggml_tensor * get_random_tensor_int(
 float get_element(const struct ggml_tensor * t, int idx) {
     if (t->type == GGML_TYPE_F32) {
         return ((float *)t->data)[idx];
-    } else if (t->type == GGML_TYPE_I32) {
+    }
+
+    if (t->type == GGML_TYPE_I32) {
         return ((int32_t *)t->data)[idx];
-    } else {
-        assert(false);
-        return INFINITY;
     }
+
+    assert(false);
+    return INFINITY;
 }
 
 void set_element(struct ggml_tensor * t, int idx, float value) {
@@ -215,15 +219,14 @@ bool check_gradient(
     }
 
     struct ggml_cgraph gf = ggml_build_forward (f);
-    gf.n_threads = n_threads;
-
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
-    gb.n_threads = n_threads;
 
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
     ggml_graph_reset  (&gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute(ctx0, &gb);
+
+    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
     // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
     // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
@@ -236,15 +239,16 @@ bool check_gradient(
             const float xm = x0 - eps;
             const float xp = x0 + eps;
             set_element(x[i], k, xp);
-            ggml_graph_compute(ctx0, &gf);
+
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
             set_element(x[i], k, xm);
-            ggml_graph_compute(ctx0, &gf);
 
-            const float f1 = ggml_get_f32_1d(f, 0);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
+            const float f1 = ggml_get_f32_1d(f, 0);
             const float g0 = (f0 - f1)/(2.0f*eps);
 
             set_element(x[i], k, x0);
@@ -252,12 +256,13 @@ bool check_gradient(
             // compute gradient using backward graph
             ggml_graph_reset  (&gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute(ctx0, &gb);
+
+            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
             const float g1 = get_element(x[i]->grad, k);
 
             const float error_abs = fabsf(g0 - g1);
-            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
 
             if (error_abs > max_error_abs || error_rel > max_error_rel) {
                 printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
index d001615ee353bfca5a02d3648f986ed636fb5fc7..e928a7df7ee68e60a9af80c46deb785d0f2a7383 100644 (file)
@@ -7,6 +7,7 @@
 
 #define MAX_NARGS 2
 
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
 
 //
 // logging
@@ -33,7 +34,7 @@
 #define GGML_PRINT(...) printf(__VA_ARGS__)
 
 
-float frand() {
+float frand(void) {
     return (float)rand()/(float)RAND_MAX;
 }
 
@@ -114,7 +115,7 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
     ((float *)t->data)[idx] = value;
 }
 
-int main(int argc, const char ** argv) {
+int main(void) {
     struct ggml_init_params params = {
         .mem_size   = 1024*1024*1024,
         .mem_buffer = NULL,
@@ -137,10 +138,11 @@ int main(int argc, const char ** argv) {
     struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
     struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
 
-
     struct ggml_cgraph ge = ggml_build_forward(e);
-    ggml_graph_reset  (&ge);
-    ggml_graph_compute(ctx, &ge);
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
     const float fe = ggml_get_f32_1d(e, 0);
     printf("%s: e = %.4f\n", __func__, fe);
 
@@ -148,8 +150,10 @@ int main(int argc, const char ** argv) {
 
     ggml_opt(ctx, opt_params, e);
 
-    ggml_graph_reset  (&ge);
-    ggml_graph_compute(ctx, &ge);
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
     const float fe_opt = ggml_get_f32_1d(e, 0);
     printf("%s: original  e = %.4f\n", __func__, fe);
     printf("%s: optimized e = %.4f\n", __func__, fe_opt);