]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : improve graph build time via hash table lookup (#2329)
authorslaren <redacted>
Tue, 25 Jul 2023 12:32:20 +0000 (14:32 +0200)
committerGitHub <redacted>
Tue, 25 Jul 2023 12:32:20 +0000 (15:32 +0300)
* improve graph build time

* ggml_tensor : use 1 bit per flag

* use a hash table instead

ggml.c
ggml.h
llama.cpp

diff --git a/ggml.c b/ggml.c
index 11226c834de7bc8d1a453fe57516fd1dd699c834..d2f5e72751056dd887af3ef6f96d7694e9db26cd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -15665,6 +15665,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }
 
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
@@ -15675,16 +15703,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
+    if (hash_insert(cgraph->visited_hash_table, node)) {
+        return;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
@@ -15747,6 +15767,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -15788,7 +15809,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_impl(&result, node->grad, true);
+            ggml_build_forward_expand(&result, node->grad);
         }
     }
 
diff --git a/ggml.h b/ggml.h
index 1870b62e8aa1fff1ee13b97aa47e5be8aec8b683..c309f1361c6f6c5494369b19fb611754f46b7dda 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -442,7 +442,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[8];
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -463,6 +463,11 @@ extern "C" {
         void * abort_callback_data;
     };
 
+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
@@ -472,6 +477,8 @@ extern "C" {
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];
 
+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         int     perf_runs;
         int64_t perf_cycles;
index 2d737bbcebc2f8ba0f57bebcc975a24af02ca997..febefbacf313a0023bd4b40346d9a108f4b028a1 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1714,6 +1714,8 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);
 
+    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
+
 #if GGML_USE_MPI
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
 #endif