]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix bug in ggml_compute_forward_sum_f32 (#1162)
authorxaedes <redacted>
Mon, 24 Apr 2023 21:02:02 +0000 (23:02 +0200)
committerGitHub <redacted>
Mon, 24 Apr 2023 21:02:02 +0000 (23:02 +0200)
The sum over all rows is now computed instead of just the last row

ggml.c

diff --git a/ggml.c b/ggml.c
index 6e46c0e5ad1dab8dba7c7111faa6502011340189..85058899574bed68e99abca882518ecda9288814 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -6779,15 +6779,20 @@ static void ggml_compute_forward_sum_f32(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
+    ggml_float sum     = 0;
+    float      row_sum = 0;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = 0; i01 < ne01; i01++) {
                 ggml_vec_sum_f32(ne00,
-                        (float *) (dst->data),
+                        &row_sum,
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
             }
         }
     }
+    ((float *) dst->data)[0] = sum;
 }
 
 static void ggml_compute_forward_sum(