]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
examples : fix simple (#770)
authorBryan Lozano <redacted>
Fri, 22 Mar 2024 07:17:34 +0000 (00:17 -0700)
committerGitHub <redacted>
Fri, 22 Mar 2024 07:17:34 +0000 (09:17 +0200)
* Update README.md

Correcting matrix multiplication expected result.

* Update simple-ctx.cpp

Fix incorrect striding through output.

* simple : update readme

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/simple/README.md
examples/simple/simple-ctx.cpp

index ba3a4fd1a4e0889828f33e6d11527f7433e8314f..28c549fc9b8f184c3b2d68f4f8ed8fd8af83dd01 100644 (file)
@@ -2,6 +2,12 @@
 
 This example simply performs a matrix multiplication, solely for the purpose of demonstrating a basic usage of ggml and backend handling. The code is commented to help understand what each part does.
 
+Traditional matrix multiplication goes like this (multiply row-by-column):
+
+$$
+A \times B = C
+$$
+
 $$
 \begin{bmatrix}
 2 & 8 \\
@@ -16,9 +22,39 @@ $$
 \end{bmatrix}
 \=
 \begin{bmatrix}
-60 & 110 & 54 & 29 \\
-55 & 90 & 126 & 28 \\
-50 & 54 & 42 & 64 \\
+60 & 90 & 42 \\
+55 & 54 & 29 \\
+50 &  54 & 28 \\
+110 & 126 & 64 \\
+\end{bmatrix}
+$$
+
+In `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:
+
+$$
+ggml\\_mul\\_mat(A, B^T) = C^T
+$$
+
+$$
+ggml\\_mul\\_mat(
+\begin{bmatrix}
+2 & 8 \\
+5 & 1 \\
+4 & 2 \\
+8 & 6 \\
+\end{bmatrix}
+,
+\begin{bmatrix}
+10 & 5 \\
+9 & 9 \\
+5 & 4 \\
+\end{bmatrix}
+)
+\=
+\begin{bmatrix}
+60 & 55 & 50 & 110 \\
+90 & 54 & 54 & 126 \\
+42 & 29 & 28 & 64 \\
 \end{bmatrix}
 $$
 
index d331a4c1ffa7ab97303e08efea1fac8a95f01c3f..b2d4e4ba5be156e4c4acc9746a80fbd1ac2c39b9 100644 (file)
@@ -104,9 +104,9 @@ int main(void) {
     memcpy(out_data.data(), result->data, ggml_nbytes(result));
 
     // expected result:
-    // [ 60.00 110.00 54.00 29.00
-    //  55.00 90.00 126.00 28.00
-    //  50.00 54.00 42.00 64.00 ]
+    // [ 60.00 55.00 50.00 110.00
+    //   90.00 54.00 54.00 126.00
+    //   42.00 29.00 28.00 64.00 ]
 
     printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
     for (int j = 0; j < result->ne[1] /* rows */; j++) {
@@ -115,7 +115,7 @@ int main(void) {
         }
 
         for (int i = 0; i < result->ne[0] /* cols */; i++) {
-            printf(" %.2f", out_data[i * result->ne[1] + j]);
+            printf(" %.2f", out_data[j * result->ne[0] + i]);
         }
     }
     printf(" ]\n");