]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : only use default buffer types for the KV cache (#10358)
authorDiego Devesa <redacted>
Sun, 17 Nov 2024 11:25:45 +0000 (12:25 +0100)
committerGitHub <redacted>
Sun, 17 Nov 2024 11:25:45 +0000 (12:25 +0100)
ggml/src/ggml-backend.cpp
src/llama.cpp

index 9a6010d361224f8292fe7fc883cac7f91715f4f6..9dcde8d11952ae4fd6c0d395081f75760c843912 100644 (file)
@@ -689,7 +689,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
 }
 
 static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
-    ggml_backend_buffer_t buffer = tensor->buffer;
+    ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
@@ -722,8 +722,6 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML
 
 // returns the backend that should be used for the node based on the current locations
 static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
-    // TODO: use supports_op to check if the backend supports the op
-
     // assign pre-allocated nodes to their backend
     int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {
@@ -742,7 +740,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
 
     if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
         // since the tensor is pre-allocated, it cannot be moved to another backend
-        GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
+        GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name);
     }
 
     // graph input
@@ -886,6 +884,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
         int * node_backend_id = &tensor_backend_id(node);
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
         // do not overwrite user assignments
         if (*node_backend_id == -1) {
             *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
index 1703104fb3680f2363fb431ebd0749baaa260926..de96959f266326bb0951b69ea634563bf2194ad6 100644 (file)
@@ -3460,21 +3460,13 @@ static bool llama_kv_cache_init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        const llama_model::buft_list_t * buft_list;
+        ggml_backend_buffer_type_t buft;
         if (offload) {
-            buft_list = model.dev_layer.at(i).buft_list;
+            auto * dev = model.dev_layer.at(i).dev;
+            buft = ggml_backend_dev_buffer_type(dev);
         } else {
-            buft_list = &model.cpu_buft_list;
+            buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_backend_buffer_type_t buft = select_buft(*buft_list,
-            [&](ggml_context * ctx) {
-                ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-                if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
-                    return k;
-                }
-                ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-                return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
-            });
         ggml_context * ctx = ctx_for_buft(buft);
 
         if (!ctx) {