]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
backend : add eval callback (#4935)
authorGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 16:39:41 +0000 (18:39 +0200)
committerGitHub <redacted>
Wed, 17 Jan 2024 16:39:41 +0000 (18:39 +0200)
* backend : add eval callback

ggml-ci

* backend : group nodes in a single compute when user don't need them

* backend : clean-up the implementation

ggml-ci

* simple : do not perform tensor data copy if not needed

* simple : fix

* simple : no need for ggml_is_contiguous + fix bool parse

* llama : fix callback placement in llama_context_params

* backend : avoid double-ask callback calls

* simple : restore examples, imatrix will serve as a demo

ggml-backend.c
ggml-backend.h
llama.cpp
llama.h

index f5424fb90411709bda6ebdf8cf721fda289e027c..4266250f926eee602c2aaf650190dde72a22fd47 100644 (file)
@@ -802,6 +802,9 @@ struct ggml_backend_sched {
     __attribute__((aligned(GGML_MEM_ALIGN)))
     #endif
     char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    ggml_backend_sched_eval_callback callback_eval;
+    void * callback_eval_user_data;
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -1324,9 +1327,38 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
         ggml_graph_dump_dot(split->graph, NULL, split_filename);
 #endif
 
+
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, &split->graph);
-        //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        if (!sched->callback_eval) {
+            ggml_backend_graph_compute(split_backend, &split->graph);
+          //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        } else {
+            // similar to ggml_backend_compare_graph_backend
+            for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+                struct ggml_tensor * t = split->graph.nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++j1];
+                    need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+                ggml_backend_graph_compute(split_backend, &gv);
+
+                if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+        }
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
     }
@@ -1431,6 +1463,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     sched_reset(sched);
 }
 
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    sched->callback_eval = callback;
+    sched->callback_eval_user_data = user_data;
+}
+
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
     return sched->n_splits;
 }
index 12b4b4ab749357fae65d07fad2f2fa2157a01ccb..ab4ad773ffbcebc9b30210ebb650770592d7ebb9 100644 (file)
@@ -148,6 +148,14 @@ extern "C" {
     struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
+    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t  ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
     GGML_API void                  ggml_backend_sched_free(ggml_backend_sched_t sched);
@@ -168,6 +176,9 @@ extern "C" {
     // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
     GGML_API void                  ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void                  ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
     //
     // Utils
     //
index 2c5983c67f67197f131ca0a51ef88192c9324d85..81829b13e4e94a5ed1c21c6cb1aa8a6b69eb1272 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1393,6 +1393,9 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 struct llama_layer {
@@ -6254,6 +6257,7 @@ static int llama_decode_internal(
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
@@ -9276,6 +9280,8 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast              =*/ 32.0f,
         /*.yarn_beta_slow              =*/ 1.0f,
         /*.yarn_orig_ctx               =*/ 0,
+        /*.cb_eval                     =*/ nullptr,
+        /*.cb_eval_user_data           =*/ nullptr,
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
         /*.mul_mat_q                   =*/ true,
@@ -9416,6 +9422,9 @@ struct llama_context * llama_new_context_with_model(
                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
                                                               hparams.n_ctx_train;
 
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
diff --git a/llama.h b/llama.h
index a570b0d6968fbf0e9dede1848e4799754ef76075..e268d7a1d0cc9bf9dd8290fdbb6da2f2959589d5 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -2,6 +2,7 @@
 #define LLAMA_H
 
 #include "ggml.h"
+#include "ggml-backend.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@@ -231,6 +232,9 @@ extern "C" {
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
 
+        ggml_backend_sched_eval_callback cb_eval;
+        void * cb_eval_user_data;
+
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache