]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llamafile : use 64-bit integers in sgemm (#6928)
authorJustine Tunney <redacted>
Fri, 26 Apr 2024 14:05:33 +0000 (10:05 -0400)
committerGitHub <redacted>
Fri, 26 Apr 2024 14:05:33 +0000 (17:05 +0300)
sgemm.cpp
sgemm.h

index 531e12af361ccf862b216ad90147c6bf3100cbda..4e0159804e8166b118d51d200dfa727a148f9b42 100644 (file)
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -50,7 +50,6 @@
 #pragma GCC diagnostic ignored "-Wignored-attributes"
 
 #include "sgemm.h"
-#include <algorithm>
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 
@@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int k,
-             const TA *A, int lda,
-             const TB *B, int ldb,
-             TC *C, int ldc,
+    tinyBLAS(int64_t k,
+             const TA *A, int64_t lda,
+             const TB *B, int64_t ldb,
+             TC *C, int64_t ldc,
              int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }
 
   private:
-    NOINLINE void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
 #if VECTOR_REGISTERS == 32
         case 0x55:
             mc = 5;
@@ -409,27 +408,27 @@ class tinyBLAS {
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             D Cv[RN][RM] = {};
-            for (int l = 0; l < k; l += KN)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; l += KN)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                         load<V>(B + ldb * (jj + j) + l),
                                         Cv[j][i]);
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -437,10 +436,10 @@ class tinyBLAS {
     const TA *const A;
     const TB *const B;
     TC *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -452,23 +451,23 @@ class tinyBLAS {
 template <typename TA>
 class tinyBLAS_Q0_ARM {
   public:
-    tinyBLAS_Q0_ARM(int k,
-                    const TA *A, int lda,
-                    const block_q8_0 *B, int ldb,
-                    float *C, int ldc,
+    tinyBLAS_Q0_ARM(int64_t k,
+                    const TA *A, int64_t lda,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }
 
   private:
-    NOINLINE void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
         case 0x33:
             mc = 3;
             nc = 3;
@@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM {
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             float32x4_t Cv[RN][RM] = {};
-            for (int l = 0; l < k; ++l)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                                vcvtq_f32_s32(vdotq_s32(
                                                    vdotq_s32(vdupq_n_s32(0),
@@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM {
                                                    load_hi(B + ldb * (jj + j) + l))),
                                                unhalf(A[lda * (ii + i) + l].d) *
                                                unhalf(B[ldb * (jj + j) + l].d));
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM {
     const TA *const A;
     const block_q8_0 *const B;
     float *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM {
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_AVX2 {
   public:
-    tinyBLAS_Q0_AVX2(int k,
-                     const TA *A, int lda,
-                     const TB *B, int ldb,
-                     TC *C, int ldc,
+    tinyBLAS_Q0_AVX2(int64_t k,
+                     const TA *A, int64_t lda,
+                     const TB *B, int64_t ldb,
+                     TC *C, int64_t ldc,
                      int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }
 
   private:
-    void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
 #if VECTOR_REGISTERS == 32
         case 0x44:
             mc = 4;
@@ -714,22 +713,22 @@ class tinyBLAS_Q0_AVX2 {
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             __m256 Cv[RN][RM] = {};
-            for (int l = 0; l < k; ++l)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                        unhalf(B[ldb * (jj + j) + l].d)),
                                         updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
@@ -737,8 +736,8 @@ class tinyBLAS_Q0_AVX2 {
                                               _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                                load(A + lda * (ii + i) + l))),
                                         Cv[j][i]);
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 {
     const TA *const A;
     const TB *const B;
     TC *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
-                     int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     assert(ldc >= m);
     assert(nth > 0);
     assert(ith < nth);
-    assert(1ll * lda * m <= 0x7fffffff);
-    assert(1ll * ldb * n <= 0x7fffffff);
-    assert(1ll * ldc * n <= 0x7fffffff);
 
     if (Ctype != GGML_TYPE_F32)
         return false;
diff --git a/sgemm.h b/sgemm.h
index da23b209c4dd5b2b9627f70aa12d22a8ae9772ac..f29747d0a477af4ef5fb24f3877ac30ffd0ccae8 100644 (file)
--- a/sgemm.h
+++ b/sgemm.h
@@ -1,11 +1,13 @@
 #pragma once
+#include <stdint.h>
 #include <stdbool.h>
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-bool llamafile_sgemm(int, int, int, const void *, int, const void *, int,
-                     void *, int, int, int, int, int, int, int);
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+                     const void *, int64_t, void *, int64_t, int, int,
+                     int, int, int, int);
 
 #ifdef __cplusplus
 }