]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: add names to tensors (#1268)
authorslaren <redacted>
Tue, 2 May 2023 14:03:00 +0000 (16:03 +0200)
committerGitHub <redacted>
Tue, 2 May 2023 14:03:00 +0000 (16:03 +0200)
* ggml: add names to tensors

* minor improvements to dot file formatting

ggml.c
ggml.h
llama.cpp

diff --git a/ggml.c b/ggml.c
index bce7a7a57e93967f7b04721b83035925d94f0867..6a9695e2351809dc2b2c4a6d91514303a27caf8d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4541,6 +4541,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
+        /*.name         =*/ { 0 },
         /*.pad          =*/ { 0 },
     };
 
@@ -4895,6 +4896,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
     return (float *)(tensor->data);
 }
 
+const char * ggml_get_name(const struct ggml_tensor * tensor) {
+    return tensor->name;
+}
+
+void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+    strncpy(tensor->name, name, sizeof(tensor->name));
+    tensor->name[sizeof(tensor->name) - 1] = '\0';
+}
+
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5994,6 +6004,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
     struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    ggml_set_name(b, "n_past");
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6051,6 +6062,7 @@ struct ggml_tensor * ggml_rope(
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ggml_set_name(b, "n_past, n_dims, mode");
 
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -12118,10 +12130,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
             snprintf(color, sizeof(color), "white");
         }
 
-        fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
-                (void *) node, color,
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+            fprintf(fp, "%s |", node->name);
+        }
+
+        fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
                 i, node->ne[0], node->ne[1],
                 GGML_OP_SYMBOL[node->op]);
 
@@ -12137,18 +12155,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
 
         snprintf(color, sizeof(color), "pink");
 
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"<x>",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+                fprintf(fp, "%s | ", node->name);
+        }
         if (ggml_nelements(node) == 1) {
-            fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
-        } else {
-            fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
-                    (void *) node, color,
-                    i, node->ne[0], node->ne[1]);
+            if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
+                fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
+            }
+            else {
+                fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
+            }
+        }
+        else {
+            fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
         }
+        fprintf(fp, "\"; ]\n");
     }
 
     for (int i = 0; i < gb->n_nodes; i++) {
diff --git a/ggml.h b/ggml.h
index ef5a048c3b7e4f6da1ad4f4cbbdc69dd88a3a213..508dd69b41713209d2339bd06148fcc96ceebec1 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -350,7 +350,10 @@ extern "C" {
         int64_t perf_time_us;
 
         void * data;
-        char padding[8];
+
+        char name[32];
+
+        char padding[8]; // TODO: remove and add padding to name?
     };
 
     // computation graph
@@ -473,6 +476,9 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
+    GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
+    GGML_API void         ggml_set_name(struct ggml_tensor * tensor, const char * name);
+
     //
     // operations on tensors with backpropagation
     //
index 868a58a8b0b93a0c881413f55331b6a6a8aced7a..b8751be7bf3c75b838a40454cea37fd713d9aae3 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -659,6 +659,7 @@ struct llama_model_loader {
             LLAMA_ASSERT(lt.ne.size() == 1);
             tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0));
         }
+        ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
@@ -798,6 +799,8 @@ static bool kv_cache_init(
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    ggml_set_name(cache.k, "cache_k");
+    ggml_set_name(cache.v, "cache_v");
 
     return true;
 }
@@ -1084,6 +1087,7 @@ static bool llama_eval_internal(
     gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_set_name(embd, "embd");
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
@@ -1110,6 +1114,8 @@ static bool llama_eval_internal(
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            ggml_set_name(Qcur, "Qcur");
+            ggml_set_name(Kcur, "Kcur");
 
             // store key and value to memory
             {
@@ -1130,6 +1136,7 @@ static bool llama_eval_internal(
                 ggml_permute(ctx0,
                         Qcur,
                         0, 2, 1, 3);
+            ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
@@ -1137,21 +1144,26 @@ static bool llama_eval_internal(
                             ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
                             n_embd/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
+            ggml_set_name(K, "K");
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
+            ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
+
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
@@ -1160,9 +1172,11 @@ static bool llama_eval_internal(
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
                         il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+            ggml_set_name(V, "V");
 
 #if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            ggml_set_name(KQV, "KQV");
 #else
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
@@ -1173,11 +1187,13 @@ static bool llama_eval_internal(
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_set_name(KQV_merged, "KQV_merged");
 
             // cur = KQV_merged.contiguous().view(n_embd, N)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            ggml_set_name(cur, "KQV_merged_contiguous");
 
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,