]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : run all KQV ops on the CPU with no KV offload (#5049)
authorslaren <redacted>
Sat, 20 Jan 2024 15:05:49 +0000 (16:05 +0100)
committerGitHub <redacted>
Sat, 20 Jan 2024 15:05:49 +0000 (17:05 +0200)
ggml-ci

ggml-backend.c
llama.cpp

index ef518dae0909b9a9d8ee3c1b886f701222b712c4..423512defc132450efad523c73bf1a71f5109147 100644 (file)
@@ -1191,6 +1191,24 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                 ggml_tallocr_t src_allocr = node_allocr(src);
                 GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
                 if (src_allocr != node_allocr) {
+                    // create a copy of the input in the split's backend
+                    size_t id = hash_id(src);
+                    if (sched->node_copies[id][cur_backend_id] == NULL) {
+                        ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
+                        struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                        ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+
+                        sched->node_copies[id][cur_backend_id] = tensor_copy;
+                        node_allocr(tensor_copy) = cur_allocr;
+                        SET_CAUSE(tensor_copy, "4.cpy");
+
+                        int n_inputs = sched->splits[cur_split].n_inputs++;
+                        GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+                        sched->splits[cur_split].inputs[n_inputs] = src;
+                    }
+                    node->src[j] = sched->node_copies[id][cur_backend_id];
+
+#if 0
                     // check if the input is already in the split
                     bool found = false;
                     for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
@@ -1206,19 +1224,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                         GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
                         sched->splits[cur_split].inputs[n_inputs] = src;
                     }
-
-                    // create a copy of the input in the split's backend
-                    size_t id = hash_id(src);
-                    if (sched->node_copies[id][cur_backend_id] == NULL) {
-                        ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
-                        struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
-                        ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
-
-                        sched->node_copies[id][cur_backend_id] = tensor_copy;
-                        node_allocr(tensor_copy) = cur_allocr;
-                        SET_CAUSE(tensor_copy, "4.cpy");
-                    }
-                    node->src[j] = sched->node_copies[id][cur_backend_id];
+#endif
                 }
             }
         }
@@ -1333,7 +1339,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
         uint64_t compute_start_us = ggml_time_us();
         if (!sched->callback_eval) {
             ggml_backend_graph_compute(split_backend, &split->graph);
-          //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+            //ggml_backend_synchronize(split_backend); // necessary to measure compute time
         } else {
             // similar to ggml_backend_compare_graph_backend
             for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
index 90579ac85dabd5fe6410eef408b9f593aa977307..909ad4ad854c43ff6e0389ad2bdf9e6c3ff48915 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -4315,6 +4315,7 @@ static struct ggml_tensor * llm_build_kqv(
           const llama_model & model,
         const llama_hparams & hparams,
        const llama_kv_cache & kv,
+         struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
@@ -4393,6 +4394,8 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
     cb(cur, "kqv_merged_cont", il);
 
+    ggml_build_forward_expand(graph, cur);
+
     cur = ggml_mul_mat(ctx, wo, cur);
     if (wo_b) {
         cb(cur, "kqv_wo", il);
@@ -4405,6 +4408,44 @@ static struct ggml_tensor * llm_build_kqv(
     return cur;
 }
 
+static struct ggml_tensor * llm_build_kv(
+        struct ggml_context * ctx,
+          const llama_model & model,
+        const llama_hparams & hparams,
+       const llama_kv_cache & kv,
+         struct ggml_cgraph * graph,
+         struct ggml_tensor * wo,
+         struct ggml_tensor * wo_b,
+         struct ggml_tensor * k_cur,
+         struct ggml_tensor * v_cur,
+         struct ggml_tensor * q_cur,
+         struct ggml_tensor * kq_mask,
+                    int64_t   n_ctx,
+                    int32_t   n_tokens,
+                    int32_t   kv_head,
+                    int32_t   n_kv,
+                    float     max_alibi_bias,
+                    float     kq_scale,
+         const llm_build_cb & cb,
+                    int       il) {
+
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(graph, k_cur);
+    ggml_build_forward_expand(graph, v_cur);
+    ggml_build_forward_expand(graph, q_cur);
+
+    llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
+
+    struct ggml_tensor * cur;
+    cur  = llm_build_kqv(ctx, model, hparams, kv, graph,
+            wo, wo_b,
+            q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
+    cb(cur, "kqv_out", il);
+
+    return cur;
+}
+
 struct llm_build_context {
     const llama_model    & model;
     const llama_hparams  & hparams;
@@ -4562,12 +4603,6 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                // these nodes are added to the graph together so that they are not reordered
-                // by doing so, the number of splits in the graph is reduced
-                ggml_build_forward_expand(gf, Qcur);
-                ggml_build_forward_expand(gf, Kcur);
-                ggml_build_forward_expand(gf, Vcur);
-
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
                     hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
@@ -4582,11 +4617,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4763,14 +4796,13 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
                 // apply ALiBi for 13B model
                 const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
 
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4892,11 +4924,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4993,11 +5023,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5200,12 +5228,9 @@ struct llm_build_context {
                         );
                 cb(Vcur, "Vcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                // TODO: not tested, could be broken
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5292,11 +5317,9 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5390,11 +5413,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5485,11 +5506,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5597,11 +5616,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5714,11 +5731,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5837,11 +5852,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5966,11 +5979,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6071,11 +6082,9 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
             struct ggml_tensor * sa_out = cur;
@@ -6172,11 +6181,9 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6283,11 +6290,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6355,6 +6360,14 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(cur, name);
         }
 
+
+        if (!lctx.cparams.offload_kqv) {
+            if (strcmp(name, "kqv_merged_cont") == 0) {
+                // all nodes between the KV store and the attention output are run on the CPU
+                ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
+            }
+        }
+
         //
         // allocate input tensors and set input data
         //