]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml, llama : avoid heavy V transpose + improvements (#775)
authorGeorgi Gerganov <redacted>
Wed, 5 Apr 2023 19:07:33 +0000 (22:07 +0300)
committerGitHub <redacted>
Wed, 5 Apr 2023 19:07:33 +0000 (22:07 +0300)
ggml :

- added ggml_view_3d()
- ggml_view_tensor() now inherits the stride too
- reimplement ggml_cpy() to account for dst stride
- no longer require tensor->data to be memory aligned

llama :

- compute RoPE on 32-bit tensors (should be more accurate)
- store RoPE-ed K in the KV cache
- store transposed V in the KV cache (significant speed-up)
- avoid unnecessary Q copy

ggml.c
ggml.h
llama.cpp

diff --git a/ggml.c b/ggml.c
index 59e84ab45d120af3b1321c270b13556712c76956..a3a3314c746d16b211ce087e8d84d6376bc6203c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3219,7 +3219,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.pad          =*/ { 0 },
     };
 
-    ggml_assert_aligned(result->data);
+    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+    //ggml_assert_aligned(result->data);
 
     for (int i = 0; i < n_dims; i++) {
         result->ne[i] = ne[i];
@@ -3620,7 +3621,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
-    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+
+    result->nb[0] = src->nb[0];
+    result->nb[1] = src->nb[1];
+    result->nb[2] = src->nb[2];
+    result->nb[3] = src->nb[3];
+
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -4510,6 +4518,37 @@ struct ggml_tensor * ggml_view_2d(
     return result;
 }
 
+// ggml_view_3d
+
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                offset) {
+    if (a->grad) {
+        GGML_ASSERT(false); // gradient propagation is not supported
+    }
+
+    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = result->nb[2]*ne2;
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
 // ggml_permute
 
 struct ggml_tensor * ggml_permute(
@@ -4845,7 +4884,6 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -4862,85 +4900,96 @@ static void ggml_compute_forward_dup_f16(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
-        if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
-
-                        memcpy(dst_ptr, src0_ptr, rs);
-
-                        id++;
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                            id++;
-                        }
-                    }
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
+        return;
+    }
 
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
 
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                            id++;
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
 
@@ -4949,7 +4998,6 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -4966,85 +5014,76 @@ static void ggml_compute_forward_dup_f32(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(float)) {
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
 
-                        memcpy(dst_ptr, src0_ptr, rs);
-
-                        id++;
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                            id++;
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
-
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
-                        }
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                            id++;
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
 
diff --git a/ggml.h b/ggml.h
index ad962b109ea89c1618d8f3522c891000b91f6a35..3c94efc35c2f78dfb5342cc098a36faee6c957a5 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -558,6 +558,16 @@ struct ggml_tensor * ggml_view_2d(
         size_t                nb1, // row stride in bytes
         size_t                offset);
 
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1, // row   stride in bytes
+        size_t                nb2, // slice stride in bytes
+        size_t                offset);
+
 struct ggml_tensor * ggml_permute(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
index e4517959a8c7788208c6495ec79a6a1c50f585ac..581a8399d0229814aab9866b19d1f568dff7794e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -810,37 +810,35 @@ static bool llama_eval_internal(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+            // compute Q and K and RoPE them
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
 
             // store key and value to memory
-            if (N >= 1) {
+            {
+                // compute the transposed [N, n_embd] V matrix
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
 
+                // important: storing RoPE-ed version of K in the KV cache!
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }
 
-            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
-                            n_past, n_rot, 0),
+                        Qcur,
                         0, 2, 1, 3);
 
-            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            n_past, n_rot, 1),
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
             // K * Q
@@ -858,18 +856,23 @@ static bool llama_eval_internal(
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                    ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+            // split cached V into n_head heads
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
 
-            // KQV = transpose(V) * KQ_soft_max
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+#if 1
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+#else
+            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+            // is there a better way?
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -955,9 +958,13 @@ static bool llama_eval_internal(
     ggml_build_forward_expand(&gf, inpL);
     ggml_graph_compute       (ctx0, &gf);
 
+    // print timing information per ggml operation (for debugging purposes)
+    // requires GGML_PERF to be defined
+    //ggml_graph_print(&gf);
+
+    // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
-    //    ggml_graph_print   (&gf);
-    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //    ggml_graph_dump_dot(&gf, NULL, "llama.dot");
     //}
 
     //embd_w.resize(n_vocab*N);