]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (changes to ggml_graph_compute() API) (#368)
authorGeorgi Gerganov <redacted>
Mon, 10 Jul 2023 18:40:05 +0000 (21:40 +0300)
committerGitHub <redacted>
Mon, 10 Jul 2023 18:40:05 +0000 (21:40 +0300)
23 files changed:
examples/dolly-v2/main.cpp
examples/gpt-2/main.cpp
examples/gpt-j/main.cpp
examples/gpt-neox/main.cpp
examples/mnist/main-cpu.cpp
examples/mnist/main-mtl.cpp
examples/mnist/main.cpp
examples/mpt/main.cpp
examples/replit/main.cpp
examples/starcoder/main.cpp
examples/whisper/whisper.cpp
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.h
src/ggml-metal.m
src/ggml-opencl.cpp
src/ggml.c
tests/test-blas0.c
tests/test-grad0.c
tests/test-mul-mat0.c
tests/test-opt.c
tests/test1.c
tests/test1.zig

index 4d497b016f4c5d7d9daddc4b4f629c341173f551..ce634213b42f0fa5514e8745bcf56b2854f95892 100644 (file)
@@ -496,7 +496,6 @@ bool dollyv2_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = { };
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -659,7 +658,7 @@ bool dollyv2_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index c6f2c7b1265ef6cf195d37ca0eff44c6bc071350..7e12eab5fb226bed30321c5010fb59e2f3a7986a 100644 (file)
@@ -429,7 +429,6 @@ bool gpt2_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -674,7 +673,7 @@ bool gpt2_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index ad99c80bc1ccdfe1fb701e458ac638b13d62b356..b42764ce274f3bd52045884c46df71ba6db7de8c 100644 (file)
@@ -425,7 +425,6 @@ bool gptj_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -584,7 +583,7 @@ bool gptj_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index 065af4dae15756e3b1a6ba684ca5e62048977473..4af771b8440f6762edbd3814f8bd12c05918b35e 100644 (file)
@@ -476,7 +476,6 @@ bool gpt_neox_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -646,7 +645,7 @@ bool gpt_neox_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index b3cde65818394acd3f37ba4629c35dc670540568..2000c9aac6274b44fe9bcde4902f6f431902291e 100644 (file)
@@ -41,11 +41,10 @@ int mnist_eval(
     struct ggml_context * ctx_eval = NULL;
 
     struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
-    gfi.n_threads = n_threads;
 
     // allocate work context
     // needed during ggml_graph_compute() to allocate a work tensor
-    static size_t buf_size = gfi.work_size; // TODO
+    static size_t buf_size = 128ull*1024*1024; // TODO
     static void * buf = malloc(buf_size);
 
     struct ggml_init_params params = {
@@ -59,7 +58,7 @@ int mnist_eval(
     struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input");
     memcpy(input->data, digit.data(), ggml_nbytes(input));
 
-    ggml_graph_compute(ctx_work, &gfi);
+    ggml_graph_compute_with_ctx(ctx_work, &gfi, n_threads);
 
     const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs"));
 
index 1f9475d85aa509aa30b366cf99aef713a6a58070..a8d47ac9c70c315e66076aa5094fb77f1b044cdf 100644 (file)
@@ -36,16 +36,15 @@ int mnist_eval(
     struct ggml_context * ctx_eval = NULL;
 
     struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
-    gf.n_threads = 1;
 
     // allocate work context
-    static size_t buf_size = gf.work_size; // TODO
+    static size_t buf_size = 128ull*1024*1024; // TODO
     static void * buf = malloc(buf_size);
 
     struct ggml_init_params params = {
-        .mem_size   = buf_size,
-        .mem_buffer = buf,
-        .no_alloc   = false,
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
+        /*.no_alloc   =*/ false,
     };
 
     struct ggml_context * ctx_work = ggml_init(params);
index fac641e0e61de12e99bb09c5306b60e73eb9a8ee..5ff4ac20fc93b9826c49ce36912ad60e23fcd077 100644 (file)
@@ -186,7 +186,6 @@ int mnist_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_input);
     memcpy(input->data, digit.data(), ggml_nbytes(input));
@@ -202,7 +201,7 @@ int mnist_eval(
 
     // build / export / run the computation graph
     ggml_build_forward_expand(&gf, probs);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //ggml_graph_print   (&gf);
     ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
index 063d85669cc7a280c503db2fcd979fbbeda413ca..457dc3d5be8a67b7aac9d5f5b0f24c8b9a3abf8f 100644 (file)
@@ -499,7 +499,6 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
@@ -651,7 +650,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // std::cout << "Qcur" << std::endl;
     // print_tensor(Qcur);
index 4b96ddfb7e22cc1940751f246f71fd612b14c2ea..1ed265bfb0d62b616f8f2853b3bbb50c4b3cb0ad 100644 (file)
@@ -475,7 +475,6 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
@@ -614,7 +613,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // std::cout << "Qcur" << std::endl;
     // print_tensor(Qcur);
index 8a694aa85a6f6b8a64185ec0b0859400a5604656..c50073045ce299ed7c0a25cbda0e01c8addcacea 100644 (file)
@@ -463,7 +463,6 @@ bool starcoder_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -716,7 +715,7 @@ bool starcoder_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
@@ -791,7 +790,7 @@ int main(int argc, char ** argv) {
     printf("%s: top_p          = %.3f\n", __func__, params.top_p);
     printf("%s: repeat_last_n  = %d\n",   __func__, params.repeat_last_n);
     printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
-    
+
     int n_past = 0;
 
     int64_t t_sample_us  = 0;
@@ -801,7 +800,7 @@ int main(int argc, char ** argv) {
 
     std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-    
+
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
 
@@ -868,14 +867,14 @@ int main(int argc, char ** argv) {
             embd.push_back(id);
 
             last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(id);            
+            last_n_tokens.push_back(id);
         } else {
             // if here, it means we are still processing the input prompt
             for (int k = i; k < embd_inp.size(); k++) {
                 embd.push_back(embd_inp[k]);
 
                 last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[k]); 
+                last_n_tokens.push_back(embd_inp[k]);
 
                 if (embd.size() >= params.n_batch) {
                     break;
index 381874573ea48a9a7ee307c4e6a00711241be728..963e955da5631b32d7b175eec107c756dc426b65 100644 (file)
@@ -1782,10 +1782,9 @@ static bool whisper_encode_internal(
         // run the computation
         {
             struct ggml_cgraph gf = {};
-            gf.n_threads = n_threads;
 
             ggml_build_forward_expand(&gf, cur);
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             //ggml_graph_print(&gf);
         }
@@ -1828,7 +1827,6 @@ static bool whisper_encode_internal(
     // pre-compute cross-attention memory
     {
         struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
 
         // TODO: hack to disconnect the encoded features from the previous graph
         cur->op = GGML_OP_NONE;
@@ -1871,7 +1869,7 @@ static bool whisper_encode_internal(
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
         }
 
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
         //ggml_graph_print(&gf);
     }
 
@@ -1942,7 +1940,6 @@ static bool whisper_decode_internal(
     struct ggml_context * ctx0 = ggml_init(params);
 
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
@@ -2286,7 +2283,7 @@ static bool whisper_decode_internal(
     // run the computation
     {
         ggml_build_forward_expand(&gf, logits);
-        ggml_graph_compute       (ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }
 
     // extract logits for all N tokens
@@ -5098,17 +5095,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
 
             struct ggml_cgraph gf = ggml_build_forward(c);
 
-            gf.n_threads = n_threads;
-
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute(ctx0, &gf);
+                ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
                 const int64_t t1 = ggml_time_us();
 
index d0710c55591700d51f8cd7d4cd7f810861d19616..ab84bef68747e00a0a6c5e8eecbe89eb11b86bbd 100644 (file)
@@ -65,7 +65,7 @@
 //       ggml_set_f32(a, 3.0f);
 //       ggml_set_f32(b, 4.0f);
 //
-//       ggml_graph_compute(ctx0, &gf);
+//       ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
 //
 //       printf("f = %f\n", ggml_get_f32_1d(f, 0));
 //
@@ -418,9 +418,6 @@ extern "C" {
         struct ggml_tensor * src1;
         struct ggml_tensor * opt[GGML_MAX_OPT];
 
-        // thread scheduling
-        int n_tasks;
-
         // performance
         int     perf_runs;
         int64_t perf_cycles;
@@ -432,19 +429,27 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[4];
+        char padding[8];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+
+        // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
+        int n_tasks[GGML_MAX_NODES];
+    };
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
         int n_leafs;
-        int n_threads;
-
-        size_t work_size;
-        struct ggml_tensor * work;
 
         struct ggml_tensor * nodes[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
@@ -1290,15 +1295,22 @@ extern "C" {
 
     GGML_API void ggml_set_param(
             struct ggml_context * ctx,
-            struct ggml_tensor * tensor);
+            struct ggml_tensor  * tensor);
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
-    GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+    GGML_API              void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
 
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
index 7965ff74111f763e33d77dded69c256545cbaf70..fd36f179b61447c4567f72bf43300c9381d75d74 100644 (file)
@@ -59,8 +59,8 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_DMMV_F16
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
@@ -131,7 +131,7 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
 
 //================================= k-quants
 
@@ -208,6 +208,7 @@ typedef struct {
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
 #define WARP_SIZE 32
+#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
@@ -407,7 +408,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 //================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i   = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -440,7 +441,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -504,7 +505,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -544,7 +545,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -590,7 +591,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
@@ -634,7 +635,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -742,7 +743,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -846,7 +847,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -949,7 +950,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
 
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
@@ -1053,7 +1054,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -1171,7 +1172,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -1180,10 +1181,10 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int ib = i / QK8_0; // block index
-    const int iqs = i % QK8_0; // quant index
+    const int ib = i / QK8_1; // block index
+    const int iqs = i % QK8_1; // quant index
 
-    const float xi = x[i];
+    const float xi = i < ndata ? x[i] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
@@ -1207,7 +1208,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
     if (i >= k) {
@@ -1227,7 +1228,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
@@ -1252,7 +1253,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1277,7 +1278,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1312,7 +1313,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1346,7 +1347,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
@@ -1366,7 +1367,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -1404,7 +1405,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1471,7 +1472,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1518,7 +1519,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 }
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int channel_stride_x) {
 
     const half * x = (const half *) vx;
@@ -1714,9 +1715,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
+    quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
 }
 
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -2355,16 +2356,15 @@ inline void ggml_cuda_op_mul_mat_vec(
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
 
-    // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
-    // However, they have bad performance with Pascal cards.
-    // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
+        int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
+        padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
         size_t as;
-        void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
-        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
+        void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
+        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
 
         switch (src0->type) {
             case GGML_TYPE_Q4_0:
@@ -3108,7 +3108,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 
 void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     int nrows = ggml_nrows(tensor);
+
+    const int64_t ne0 = tensor->ne[0];
+
     const size_t nb1 = tensor->nb[1];
+
     ggml_backend backend = tensor->backend;
     struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
@@ -3137,11 +3141,24 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         int64_t nrows_split = row_high - row_low;
 
         const size_t offset_split = row_low*nb1;
-        const size_t size = ggml_nbytes_split(tensor, nrows_split);
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+                * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+        }
 
-        void * buf;
+        char * buf;
         CUDA_CHECK(cudaMalloc(&buf, size));
-        void * buf_host = (char*)data + offset_split;
+        char * buf_host = (char*)data + offset_split;
+
+        // set padding to 0 to avoid possible NaN values
+        if (size > original_size) {
+            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+        }
+
 
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
index b9e50ac745eb077f927a51fef6ab1af65de1e032..928f1705c381cf710468968b3f5a66e19e5a0c47 100644 (file)
@@ -34,9 +34,13 @@ extern "C" {
 
 struct ggml_metal_context;
 
-struct ggml_metal_context * ggml_metal_init(void);
+// number of command buffers to use
+struct ggml_metal_context * ggml_metal_init(int n_cb);
 void ggml_metal_free(struct ggml_metal_context * ctx);
 
+// set the number of command buffers to use
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
+
 // creates a mapping between a host memory buffer and a device memory buffer
 // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
 // - the mapping is used during computation to determine the arguments of the compute kernels
index fd69c41fe357d6e1636f2640186c7908352450b9..6473644c24204e85d91ff80d92a52992de4e40c4 100644 (file)
@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
 };
 
 struct ggml_metal_context {
+    int n_cb;
+
     float * logits;
 
     id<MTLDevice>       device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
 @implementation GGMLMetalClass
 @end
 
-struct ggml_metal_context * ggml_metal_init(void) {
+struct ggml_metal_context * ggml_metal_init(int n_cb) {
     fprintf(stderr, "%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
+    ctx->n_cb   = n_cb;
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     free(ctx);
 }
 
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
+    ctx->n_cb = n_cb;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
-    const int n_cb = gf->n_threads;
+    const int n_cb = ctx->n_cb;
 
     NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
 
@@ -443,6 +450,7 @@ void ggml_metal_graph_compute(
                 //}
 
                 switch (dst->op) {
+                    case GGML_OP_NONE:
                     case GGML_OP_RESHAPE:
                     case GGML_OP_VIEW:
                     case GGML_OP_TRANSPOSE:
index fa0bdbefb1de4fba5b64b39707dfca742236e12a..eb214a836489b04a4e3d6b9c6444ab30570edb89 100644 (file)
@@ -653,13 +653,17 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
     const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
     const int in = tid - step*im;                        // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
+\n#if K_QUANTS_PER_ITERATION == 1\n
     const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
     const int is = 0;
-#else
+
+\n#else\n
+
     const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
     const int is = in / 4;
-#endif
+
+\n#endif\n
+
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;
@@ -676,7 +680,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
 
         const float d = vload_half(0, &x[i].d);
 
-#if K_QUANTS_PER_ITERATION == 1
+\n#if K_QUANTS_PER_ITERATION == 1\n
         float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
                   + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
                   + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
@@ -686,7 +690,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
                   + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
                   +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
         tmp[16 * ix + tid] += sum;
-#else
+\n#else\n
         float sum = 0;
         for (int l = 0; l < 4; ++l) {
             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@@ -695,7 +699,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
                  + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
         }
         tmp[16 * ix + tid] += sum;
-#endif
+\n#endif\n
 
     }
 
index d257c3d657b34681cce41893456a64e7fa5dbdee..c10877a761dd821dd44b32b827e88863e4909fe4 100644 (file)
@@ -247,7 +247,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include "ggml-opencl.h"
 #endif
 #elif defined(GGML_USE_OPENBLAS)
+#if defined(GGML_BLAS_USE_MKL)
+#include <mkl.h>
+#else
 #include <cblas.h>
+#endif
 #elif defined(GGML_USE_CUBLAS)
 #include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
@@ -4583,14 +4587,13 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.src0         =*/ NULL,
         /*.src1         =*/ NULL,
         /*.opt          =*/ { NULL },
-        /*.n_tasks      =*/ 0,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        /*.pad          =*/ { 0 },
+        /*.padding      =*/ { 0 },
     };
 
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -10718,8 +10721,6 @@ static void ggml_compute_forward_mul_mat(
 
         float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-        assert(ne00 % 32 == 0);
-
         for (int64_t ic = 0; ic < ne11; ++ic) {
             vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
@@ -15772,9 +15773,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
-        /*.n_threads    =*/ GGML_DEFAULT_N_THREADS,
-        /*.work_size    =*/ 0,
-        /*.work         =*/ NULL,
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
@@ -15945,12 +15943,13 @@ void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
-    struct ggml_cgraph * cgraph;
+    const struct ggml_cgraph * cgraph;
+    const struct ggml_cplan  * cplan;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
 
-    int n_threads;
+    const int n_threads;
 
     // synchronization primitives
     atomic_int n_active; // num active threads
@@ -15974,9 +15973,13 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
 
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-    struct ggml_cgraph * cgraph = state->shared->cgraph;
 
-    const int n_threads = state->shared->n_threads;
+    const struct ggml_cgraph * cgraph = state->shared->cgraph;
+    const struct ggml_cplan  * cplan  = state->shared->cplan;
+
+    const int * n_tasks_arr = cplan->n_tasks;
+    const int   n_threads   = state->shared->n_threads;
+
     set_numa_thread_affinity(state->ith, n_threads);
 
     int node_n = -1;
@@ -15989,15 +15992,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /*.type  =*/ GGML_TASK_FINALIZE,
                 /*.ith   =*/ 0,
                 /*.nth   =*/ 0,
-                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+                /*.wsize =*/ cplan->work_size,
+                /*.wdata =*/ cplan->work_data,
             };
 
             if (node_n != -1) {
                 /* FINALIZE */
                 struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = node->n_tasks;
+                    params.nth = n_tasks_arr[node_n];
                     ggml_compute_forward(&params, node);
                     ggml_graph_compute_perf_stats_node(node, state->shared);
                 }
@@ -16008,11 +16011,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
 
                 struct ggml_tensor * node = cgraph->nodes[node_n];
+                const int n_tasks = n_tasks_arr[node_n];
 
                 state->shared->perf_node_start_cycles  = ggml_perf_cycles();
                 state->shared->perf_node_start_time_us = ggml_perf_time_us();
 
-                params.nth = node->n_tasks;
+                params.nth = n_tasks;
 
                 /* INIT */
                 if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16024,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                     ggml_compute_forward(&params, node);
                 }
 
-                if (node->n_tasks == 1) {
+                if (n_tasks == 1) {
                     // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
                     // they do something more efficient than spinning (?)
                     params.type = GGML_TASK_COMPUTE;
@@ -16042,7 +16046,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             // wait for other threads to finish
             const int last = node_n;
             do {
-                sched_yield();
+                //sched_yield();
                 node_n = atomic_load(&state->shared->node_n);
             } while (node_n == last);
         }
@@ -16052,16 +16056,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         /* COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
+        const int n_tasks = n_tasks_arr[node_n];
 
         struct ggml_compute_params params = {
             /*.type  =*/ GGML_TASK_COMPUTE,
             /*.ith   =*/ state->ith,
-            /*.nth   =*/ node->n_tasks,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            /*.nth   =*/ n_tasks,
+            /*.wsize =*/ cplan->work_size,
+            /*.wdata =*/ cplan->work_data,
         };
 
-        if (state->ith < node->n_tasks) {
+        if (state->ith < n_tasks) {
             ggml_compute_forward(&params, node);
         }
     }
@@ -16069,349 +16074,372 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     return 0;
 }
 
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    const int n_threads = cgraph->n_threads;
+struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+    if (n_threads <= 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+    }
 
-    struct ggml_compute_state_shared state_shared = {
-        /*.cgraph                  =*/ cgraph,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
-        /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-    };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+    size_t work_size = 0;
 
-    // initialize tasks + work buffer
-    {
-        size_t work_size = 0;
+    struct ggml_cplan cplan;
+    memset(&cplan, 0, sizeof(struct ggml_cplan));
 
-        // thread scheduling for the different operations
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            struct ggml_tensor * node = cgraph->nodes[i];
+    // thread scheduling for the different operations + work buffer size estimation
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        int n_tasks = 1;
 
-            switch (node->op) {
-                case GGML_OP_CPY:
-                case GGML_OP_DUP:
-                    {
-                        node->n_tasks = n_threads;
+        struct ggml_tensor * node = cgraph->nodes[i];
 
-                        size_t cur = 0;
-                        if (ggml_is_quantized(node->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
-                        }
+        switch (node->op) {
+            case GGML_OP_CPY:
+            case GGML_OP_DUP:
+                {
+                    n_tasks = n_threads;
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_ADD:
-                case GGML_OP_ADD1:
-                    {
-                        node->n_tasks = n_threads;
+                    size_t cur = 0;
+                    if (ggml_is_quantized(node->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
+                    }
 
-                        size_t cur = 0;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_ADD:
+            case GGML_OP_ADD1:
+                {
+                    n_tasks = n_threads;
 
-                        if (ggml_is_quantized(node->src0->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
-                        }
+                    size_t cur = 0;
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_ACC:
-                    {
-                        node->n_tasks = n_threads;
+                    if (ggml_is_quantized(node->src0->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
+                    }
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_ACC:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        if (ggml_is_quantized(node->src0->type)) {
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
-                        }
+                    if (ggml_is_quantized(node->src0->type)) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_SUB:
-                case GGML_OP_DIV:
-                case GGML_OP_SQR:
-                case GGML_OP_SQRT:
-                case GGML_OP_LOG:
-                case GGML_OP_SUM:
-                case GGML_OP_SUM_ROWS:
-                case GGML_OP_MEAN:
-                case GGML_OP_ARGMAX:
-                case GGML_OP_REPEAT:
-                case GGML_OP_REPEAT_BACK:
-                case GGML_OP_ABS:
-                case GGML_OP_SGN:
-                case GGML_OP_NEG:
-                case GGML_OP_STEP:
-                case GGML_OP_TANH:
-                case GGML_OP_ELU:
-                case GGML_OP_RELU:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_MUL:
-                case GGML_OP_GELU:
-                case GGML_OP_GELU_QUICK:
-                case GGML_OP_SILU:
-                case GGML_OP_SILU_BACK:
-                case GGML_OP_NORM:
-                case GGML_OP_RMS_NORM:
-                case GGML_OP_RMS_NORM_BACK:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
-                case GGML_OP_MUL_MAT:
-                case GGML_OP_OUT_PROD:
-                    {
-                        node->n_tasks = n_threads;
-
-                        // TODO: use different scheduling for different matrix sizes
-                        //const int nr0 = ggml_nrows(node->src0);
-                        //const int nr1 = ggml_nrows(node->src1);
-
-                        //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
-                        //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
-
-                        size_t cur = 0;
-                        const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_SUB:
+            case GGML_OP_DIV:
+            case GGML_OP_SQR:
+            case GGML_OP_SQRT:
+            case GGML_OP_LOG:
+            case GGML_OP_SUM:
+            case GGML_OP_SUM_ROWS:
+            case GGML_OP_MEAN:
+            case GGML_OP_ARGMAX:
+            case GGML_OP_REPEAT:
+            case GGML_OP_REPEAT_BACK:
+            case GGML_OP_ABS:
+            case GGML_OP_SGN:
+            case GGML_OP_NEG:
+            case GGML_OP_STEP:
+            case GGML_OP_TANH:
+            case GGML_OP_ELU:
+            case GGML_OP_RELU:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_MUL:
+            case GGML_OP_GELU:
+            case GGML_OP_GELU_QUICK:
+            case GGML_OP_SILU:
+            case GGML_OP_SILU_BACK:
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+            case GGML_OP_RMS_NORM_BACK:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_OUT_PROD:
+                {
+                    n_tasks = n_threads;
+
+                    // TODO: use different scheduling for different matrix sizes
+                    //const int nr0 = ggml_nrows(node->src0);
+                    //const int nr1 = ggml_nrows(node->src1);
+
+                    //n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                    //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
+
+                    size_t cur = 0;
+                    const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
 
 #if defined(GGML_USE_CUBLAS)
-                        if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                //       the threads are still spinning
-                        }
-                        else
+                    if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                    } else
 #elif defined(GGML_USE_CLBLAST)
-                        if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                //       the threads are still spinning
-                            cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
-                        }
-                        else
+                    if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                        cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
+                    } else
 #endif
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                            node->n_tasks = 1; // TODO: this actually is doing nothing
-                                               //       the threads are still spinning
-                            if (node->src0->type != GGML_TYPE_F32) {
-                                // here we need memory just for single 2D matrix from src0
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            }
-                        } else
-#endif
-                        if (node->src1->type != vec_dot_type) {
-                            cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
-                        } else {
-                            cur = 0;
+                    if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                        n_tasks = 1; // TODO: this actually is doing nothing
+                                     //       the threads are still spinning
+                        if (node->src0->type != GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                         }
+                    } else
+#endif
+                    if (node->src1->type != vec_dot_type) {
+                        cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
+                    } else {
+                        cur = 0;
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_SCALE:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_SET:
-                case GGML_OP_CONT:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_PERMUTE:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_GET_ROWS:
-                case GGML_OP_GET_ROWS_BACK:
-                case GGML_OP_DIAG:
-                case GGML_OP_DIAG_MASK_ZERO:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                case GGML_OP_SOFT_MAX:
-                case GGML_OP_SOFT_MAX_BACK:
-                case GGML_OP_ROPE:
-                case GGML_OP_ROPE_BACK:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
-                case GGML_OP_ALIBI:
-                    {
-                        node->n_tasks = 1; //TODO
-                    } break;
-                case GGML_OP_CLAMP:
-                    {
-                        node->n_tasks = 1; //TODO
-                    } break;
-                case GGML_OP_CONV_1D:
-                    {
-                        node->n_tasks = n_threads;
-
-                        GGML_ASSERT(node->src0->ne[3] == 1);
-                        GGML_ASSERT(node->src1->ne[2] == 1);
-                        GGML_ASSERT(node->src1->ne[3] == 1);
-
-                        size_t cur = 0;
-                        const int nk = node->src0->ne[0];
-
-                        if (node->src0->type == GGML_TYPE_F16 &&
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_SCALE:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_SET:
+            case GGML_OP_CONT:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+            case GGML_OP_GET_ROWS:
+            case GGML_OP_GET_ROWS_BACK:
+            case GGML_OP_DIAG:
+            case GGML_OP_DIAG_MASK_ZERO:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_DIAG_MASK_INF:
+            case GGML_OP_SOFT_MAX:
+            case GGML_OP_SOFT_MAX_BACK:
+            case GGML_OP_ROPE:
+            case GGML_OP_ROPE_BACK:
+                {
+                    n_tasks = n_threads;
+                } break;
+            case GGML_OP_ALIBI:
+                {
+                    n_tasks = 1; //TODO
+                } break;
+            case GGML_OP_CLAMP:
+                {
+                    n_tasks = 1; //TODO
+                } break;
+            case GGML_OP_CONV_1D:
+                {
+                    n_tasks = n_threads;
+
+                    GGML_ASSERT(node->src0->ne[3] == 1);
+                    GGML_ASSERT(node->src1->ne[2] == 1);
+                    GGML_ASSERT(node->src1->ne[3] == 1);
+
+                    size_t cur = 0;
+                    const int nk = node->src0->ne[0];
+
+                    if (node->src0->type == GGML_TYPE_F16 &&
                             node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(ggml_fp16_t)*(
-                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
-                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
-                                    );
-                        } else if (node->src0->type == GGML_TYPE_F32 &&
-                                   node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(float)*(
-                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
-                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
-                                    );
-                        } else {
-                            GGML_ASSERT(false);
-                        }
+                        cur = sizeof(ggml_fp16_t)*(
+                                nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                );
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*(
+                                nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                );
+                    } else {
+                        GGML_ASSERT(false);
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_CONV_2D:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_2D:
+                {
+                    n_tasks = n_threads;
 
-                        GGML_ASSERT(node->src1->ne[3] == 1);
+                    GGML_ASSERT(node->src1->ne[3] == 1);
 
-                        const int64_t ne00 = node->src0->ne[0]; // W
-                        const int64_t ne01 = node->src0->ne[1]; // H
-                        const int64_t ne02 = node->src0->ne[2]; // C
-                        const int64_t ne03 = node->src0->ne[3]; // N
+                    const int64_t ne00 = node->src0->ne[0]; // W
+                    const int64_t ne01 = node->src0->ne[1]; // H
+                    const int64_t ne02 = node->src0->ne[2]; // C
+                    const int64_t ne03 = node->src0->ne[3]; // N
 
-                        const int64_t ne10 = node->src1->ne[0]; // W
-                        const int64_t ne11 = node->src1->ne[1]; // H
-                        const int64_t ne12 = node->src1->ne[2]; // C
+                    const int64_t ne10 = node->src1->ne[0]; // W
+                    const int64_t ne11 = node->src1->ne[1]; // H
+                    const int64_t ne12 = node->src1->ne[2]; // C
 
-                        const int64_t nk = ne00*ne01;
+                    const int64_t nk = ne00*ne01;
 
-                        UNUSED(ne02);
-                        UNUSED(ne03);
-                        UNUSED(nk);
+                    UNUSED(ne02);
+                    UNUSED(ne03);
+                    UNUSED(nk);
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        if (node->src0->type == GGML_TYPE_F16 &&
+                    if (node->src0->type == GGML_TYPE_F16 &&
                             node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
-                        } else if (node->src0->type == GGML_TYPE_F32 &&
-                                   node->src1->type == GGML_TYPE_F32) {
-                            cur = sizeof(float)*      (ne10*ne11*ne12);
-                        } else {
-                            GGML_ASSERT(false);
-                        }
+                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)*      (ne10*ne11*ne12);
+                    } else {
+                        GGML_ASSERT(false);
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_ATTN:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
 
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_FF:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_FF:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_FLASH_ATTN_BACK:
-                    {
-                        node->n_tasks = n_threads;
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    n_tasks = n_threads;
 
-                        size_t cur = 0;
+                    size_t cur = 0;
 
-                        const int64_t    D = node->src0->ne[0];
-                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
-                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
-                        if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
-                        }
+                    const int64_t    D = node->src0->ne[0];
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    }
 
-                        if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
-                        }
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                    }
 
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_WIN_PART:
-                case GGML_OP_WIN_UNPART:
-                case GGML_OP_MAP_UNARY:
-                case GGML_OP_MAP_BINARY:
-                case GGML_OP_MAP_CUSTOM1:
-                case GGML_OP_MAP_CUSTOM2:
-                case GGML_OP_MAP_CUSTOM3:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_CROSS_ENTROPY_LOSS:
-                    {
-                        node->n_tasks = n_threads;
-
-                        size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
-
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-                    {
-                        node->n_tasks = n_threads;
-
-                        size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
-
-                        work_size = MAX(work_size, cur);
-                    } break;
-                case GGML_OP_NONE:
-                    {
-                        node->n_tasks = 1;
-                    } break;
-                case GGML_OP_COUNT:
-                    {
-                        GGML_ASSERT(false);
-                    } break;
-            }
-        }
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_WIN_PART:
+            case GGML_OP_WIN_UNPART:
+            case GGML_OP_MAP_UNARY:
+            case GGML_OP_MAP_BINARY:
+            case GGML_OP_MAP_CUSTOM1:
+            case GGML_OP_MAP_CUSTOM2:
+            case GGML_OP_MAP_CUSTOM3:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                {
+                    n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
 
-        if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            GGML_ASSERT(false); // TODO: better handling
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_NONE:
+                {
+                    n_tasks = 1;
+                } break;
+            case GGML_OP_COUNT:
+                {
+                    GGML_ASSERT(false);
+                } break;
         }
 
-        if (work_size > 0 && cgraph->work == NULL) {
-            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+        cplan.n_tasks[i] = n_tasks;
+    }
+
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads - 1);
+    }
 
-            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
-            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+    cplan.n_threads = n_threads;
+    cplan.work_size = work_size;
+    cplan.work_data = NULL;
+
+    return cplan;
+}
+
+void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+    {
+        GGML_ASSERT(cplan);
+        GGML_ASSERT(cplan->n_threads > 0);
+
+        if (cplan->work_size > 0) {
+            GGML_ASSERT(cplan->work_data);
+        }
+
+        for (int i = 0; i < cgraph->n_nodes; ++i) {
+            if (cgraph->nodes[i]->op != GGML_OP_NONE) {
+                GGML_ASSERT(cplan->n_tasks[i] > 0);
+            }
         }
     }
 
+    const int n_threads = cplan->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.cgraph                  =*/ cgraph,
+        /*.cgraph_plan             =*/ cplan,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
+    };
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
     // create thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; ++j) {
@@ -16473,6 +16501,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     }
 }
 
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+    GGML_ASSERT(buf);
+
+    cplan.work_data = buf->data;
+
+    ggml_graph_compute(cgraph, &cplan);
+}
+
 struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,14 +16550,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
             ne[0], ne[1], ne[2], ne[3],
             nb[0], nb[1], nb[2], nb[3],
-            tensor->n_tasks,
             tensor->data,
             tensor->name);
 }
@@ -17254,9 +17292,6 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_cgraph * gb) {
     GGML_ASSERT(ggml_is_scalar(f));
 
-    gf->n_threads = params.n_threads;
-    gb->n_threads = params.n_threads;
-
     // these will store the parameters we want to optimize
     struct ggml_tensor * ps[GGML_MAX_PARAMS];
 
@@ -17303,7 +17338,8 @@ static enum ggml_opt_result ggml_opt_adam(
     // compute the function value
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute(ctx, gb);
+
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17419,8 @@ static enum ggml_opt_result ggml_opt_adam(
 
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
-        ggml_graph_compute(ctx, gb);
+
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
@@ -17505,7 +17542,8 @@ static enum ggml_opt_result linesearch_backtracking(
 
             ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute(ctx, gb);
+
+            ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
 
             ggml_opt_get_grad(np, ps, g);
 
@@ -17573,9 +17611,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-    gf->n_threads = params.n_threads;
-    gb->n_threads = params.n_threads;
-
     const int m = params.lbfgs.m;
 
     // these will store the parameters we want to optimize
@@ -17627,7 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
-        ggml_graph_compute(ctx, gb);
+
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         ggml_opt_get_grad(np, ps, g);
 
index aecccf02e3346d856512f51f8e4ea9f03eeef64e..0977d3ef83770746352820753f8abcc501c81101 100644 (file)
@@ -46,6 +46,8 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
+    const int n_threads = 1;
+
     int M = atoi(argv[1]);
     int N = atoi(argv[2]);
     int K = atoi(argv[3]);
@@ -131,14 +133,14 @@ int main(int argc, const char ** argv) {
         dst2 = ggml_mul_mat(ctx0, s0_f32, s1_f32);
 
         struct ggml_cgraph gf = ggml_build_forward(dst2);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }
 
     {
         dst3 = ggml_mul_mat(ctx0, s0_f16, s1_f32);
 
         struct ggml_cgraph gf = ggml_build_forward(dst3);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }
 
     bool ok_blas = true;
index a3e25214b84eb29766ff2a879d7d40edcdaf3fc1..da4001ce5269fbac5c9fa720396954f996662d9b 100644 (file)
@@ -10,6 +10,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+
 #define MAX_NARGS 3
 
 #undef MIN
@@ -49,7 +51,7 @@ float frand(void) {
 
 int irand(int n) {
     if (n == 0) return 0;
-    else return rand()%n;
+    return rand()%n;
 }
 
 void get_random_dims(int64_t * dims, int ndims) {
@@ -159,12 +161,14 @@ struct ggml_tensor * get_random_tensor_int(
 float get_element(const struct ggml_tensor * t, int idx) {
     if (t->type == GGML_TYPE_F32) {
         return ((float *)t->data)[idx];
-    } else if (t->type == GGML_TYPE_I32) {
+    }
+
+    if (t->type == GGML_TYPE_I32) {
         return ((int32_t *)t->data)[idx];
-    } else {
-        assert(false);
-        return INFINITY;
     }
+
+    assert(false);
+    return INFINITY;
 }
 
 void set_element(struct ggml_tensor * t, int idx, float value) {
@@ -215,15 +219,14 @@ bool check_gradient(
     }
 
     struct ggml_cgraph gf = ggml_build_forward (f);
-    gf.n_threads = n_threads;
-
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
-    gb.n_threads = n_threads;
 
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
     ggml_graph_reset  (&gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute(ctx0, &gb);
+
+    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
     // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
     // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
@@ -236,15 +239,16 @@ bool check_gradient(
             const float xm = x0 - eps;
             const float xp = x0 + eps;
             set_element(x[i], k, xp);
-            ggml_graph_compute(ctx0, &gf);
+
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
             set_element(x[i], k, xm);
-            ggml_graph_compute(ctx0, &gf);
 
-            const float f1 = ggml_get_f32_1d(f, 0);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
+            const float f1 = ggml_get_f32_1d(f, 0);
             const float g0 = (f0 - f1)/(2.0f*eps);
 
             set_element(x[i], k, x0);
@@ -252,12 +256,13 @@ bool check_gradient(
             // compute gradient using backward graph
             ggml_graph_reset  (&gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute(ctx0, &gb);
+
+            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
             const float g1 = get_element(x[i]->grad, k);
 
             const float error_abs = fabsf(g0 - g1);
-            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
 
             if (error_abs > max_error_abs || error_rel > max_error_rel) {
                 printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
index 185df39651fd82432b1c3cec0ff9a072b434096d..1bd6e140ba86b407fc5edf898d9f0f0593d12fb2 100644 (file)
@@ -95,14 +95,15 @@ bool check_gradient(
         float eps,
         float max_error_abs,
         float max_error_rel) {
+    const int n_threads = 1;
 
     struct ggml_cgraph gf = ggml_build_forward (f);
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
 
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     ggml_graph_reset  (&gf);
     ggml_set_f32      (f->grad, 1.0f);
-    ggml_graph_compute(ctx0, &gb);
+    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
     ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
     ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
@@ -114,12 +115,12 @@ bool check_gradient(
             const float x0 = get_element(x[i], k);
 
             set_element(x[i], k, x0 + eps);
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
             set_element(x[i], k, x0 - eps);
-            ggml_graph_compute(ctx0, &gf);
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f1 = ggml_get_f32_1d(f, 0);
 
@@ -130,7 +131,7 @@ bool check_gradient(
             // compute gradient using backward graph
             ggml_graph_reset  (&gf);
             ggml_set_f32      (f->grad, 1.0f);
-            ggml_graph_compute(ctx0, &gb);
+            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
             const float g1 = get_element(x[i]->grad, k);
 
@@ -247,6 +248,9 @@ int main(int argc, const char ** argv) {
     if (argc > 1) {
         niter = atoi(argv[1]);
     }
+
+    int n_threads = 1;
+
     for (int iter = 0; iter < niter; ++iter) {
         printf("test-mul-mat0: iter:%d/%d\n", iter, niter);
         struct ggml_context * ctx0 = ggml_init(params);
@@ -283,7 +287,7 @@ int main(int argc, const char ** argv) {
                     check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
                 } else {
                     struct ggml_cgraph gf = ggml_build_forward(m);
-                    ggml_graph_compute(ctx0, &gf);
+                    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
                 }
 
                 check_mat_mul(m, x[1], x[0]);
@@ -319,7 +323,7 @@ int main(int argc, const char ** argv) {
                     check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
                 } else {
                     struct ggml_cgraph gf = ggml_build_forward(m);
-                    ggml_graph_compute(ctx0, &gf);
+                    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
                 }
 
                 check_mat_mul(m, x[1], x[0]);
index d001615ee353bfca5a02d3648f986ed636fb5fc7..e928a7df7ee68e60a9af80c46deb785d0f2a7383 100644 (file)
@@ -7,6 +7,7 @@
 
 #define MAX_NARGS 2
 
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
 
 //
 // logging
@@ -33,7 +34,7 @@
 #define GGML_PRINT(...) printf(__VA_ARGS__)
 
 
-float frand() {
+float frand(void) {
     return (float)rand()/(float)RAND_MAX;
 }
 
@@ -114,7 +115,7 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
     ((float *)t->data)[idx] = value;
 }
 
-int main(int argc, const char ** argv) {
+int main(void) {
     struct ggml_init_params params = {
         .mem_size   = 1024*1024*1024,
         .mem_buffer = NULL,
@@ -137,10 +138,11 @@ int main(int argc, const char ** argv) {
     struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
     struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
 
-
     struct ggml_cgraph ge = ggml_build_forward(e);
-    ggml_graph_reset  (&ge);
-    ggml_graph_compute(ctx, &ge);
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
     const float fe = ggml_get_f32_1d(e, 0);
     printf("%s: e = %.4f\n", __func__, fe);
 
@@ -148,8 +150,10 @@ int main(int argc, const char ** argv) {
 
     ggml_opt(ctx, opt_params, e);
 
-    ggml_graph_reset  (&ge);
-    ggml_graph_compute(ctx, &ge);
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
     const float fe_opt = ggml_get_f32_1d(e, 0);
     printf("%s: original  e = %.4f\n", __func__, fe);
     printf("%s: optimized e = %.4f\n", __func__, fe_opt);
index 8c1a352e29333a18f70cc0b9f304a130c6bc55da..c313bf8e1b662a960cab28f1f397a90b9972943b 100644 (file)
@@ -4,6 +4,8 @@
 #include <stdlib.h>
 
 int main(int argc, const char ** argv) {
+    const int n_threads = 2;
+
     struct ggml_init_params params = {
         .mem_size   = 128*1024*1024,
         .mem_buffer = NULL,
@@ -35,7 +37,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(f->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("f     = %f\n", ggml_get_f32_1d(f, 0));
         printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
@@ -48,7 +50,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(f->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("f     = %f\n", ggml_get_f32_1d(f, 0));
         printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
@@ -82,7 +84,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -101,7 +103,7 @@ int main(int argc, const char ** argv) {
         ggml_set_f32(g1->grad, 1.0f);
         ggml_set_f32(g2->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gbb);
+        ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
 
         printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0));
 
@@ -132,7 +134,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -169,7 +171,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
@@ -192,7 +194,7 @@ int main(int argc, const char ** argv) {
         ggml_set_f32(g2->grad, 1.0f);
         ggml_set_f32(g3->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gbb);
+        ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
 
         printf("H * [1, 1, 1] = [ %f %f %f ]\n",
                 ggml_get_f32_1d(x1->grad, 0),
@@ -227,7 +229,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -280,7 +282,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -333,7 +335,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -380,7 +382,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
@@ -406,7 +408,7 @@ int main(int argc, const char ** argv) {
         ggml_graph_reset(&gf);
         ggml_set_f32(y->grad, 1.0f);
 
-        ggml_graph_compute(ctx0, &gb);
+        ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
         printf("y      = %f\n", ggml_get_f32_1d(y, 0));
         printf("df/dx1 = %f %f %f\n",
index 60fed53d194c4fb0a821bb0e193da67fa14b1c0f..f331acbd81b61d3babad7b0cccb45e795f2b3d9b 100644 (file)
@@ -4,6 +4,8 @@ const c = @cImport({
 });\r
 \r
 pub fn main() !void {\r
+    const n_threads = 2;\r
+\r
     const params = .{\r
         .mem_size   = 128*1024*1024,\r
         .mem_buffer = null,\r
@@ -36,7 +38,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(f.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
         std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
@@ -49,7 +51,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(f.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
         std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
@@ -83,7 +85,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
@@ -102,7 +104,7 @@ pub fn main() !void {
         _ = c.ggml_set_f32(g1.*.grad, 1.0);\r
         _ = c.ggml_set_f32(g2.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gbb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gbb), n_threads);\r
 \r
         std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0)});\r
 \r
@@ -133,7 +135,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
@@ -170,7 +172,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
@@ -193,7 +195,7 @@ pub fn main() !void {
         _ = c.ggml_set_f32(g2.*.grad, 1.0);\r
         _ = c.ggml_set_f32(g3.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gbb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gbb), n_threads);\r
 \r
         std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n",\r
             .{\r
@@ -230,7 +232,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
@@ -287,7 +289,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
@@ -344,7 +346,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
@@ -395,7 +397,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
@@ -425,7 +427,7 @@ pub fn main() !void {
         c.ggml_graph_reset(@constCast(&gf));\r
         _ = c.ggml_set_f32(y.*.grad, 1.0);\r
 \r
-        c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+        c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
 \r
         std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
         std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r