]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest changes from ggml and llama.cpp
authorGeorgi Gerganov <redacted>
Thu, 13 Apr 2023 15:53:44 +0000 (18:53 +0300)
committerGeorgi Gerganov <redacted>
Thu, 13 Apr 2023 15:53:44 +0000 (18:53 +0300)
Makefile
ggml.c
ggml.h

index ecf07cacfc37880655c7e80b79e291d69dd17273..438e3f4fc46c81fcba634ff804eee981de636a91 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -157,7 +157,7 @@ endif
 ifneq ($(filter armv7%,$(UNAME_M)),)
        # 32-bit ARM, for example on Armbian or possibly raspbian
        CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
-       
+
        # 64-bit ARM, use these (TODO: auto-detect 64-bit)
        # CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
 endif
@@ -190,7 +190,7 @@ default: main bench
 ggml.o: ggml.c ggml.h
        $(CC)  $(CFLAGS)   -c ggml.c -o ggml.o
 
-whisper.o: whisper.cpp whisper.h
+whisper.o: whisper.cpp whisper.h ggml.h
        $(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
 
 libwhisper.a: ggml.o whisper.o
diff --git a/ggml.c b/ggml.c
index 71dbb3bb3bd0fb47c919806f96078cab7e9fb6f8..42e3ee314424d5f8dada77990a03c79ef24b6c58 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1,4 +1,4 @@
-// Defines CLOCK_MONOTONIC and asprintf on Linux
+// Defines CLOCK_MONOTONIC on Linux
 #define _GNU_SOURCE
 
 #include "ggml.h"
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
-#if defined _MSC_VER || defined(__MINGW32__)
+#if defined(_WIN32)
 
-#if !defined(__MINGW32__)
-#include <Windows.h>
-#else
-// ref: https://github.com/ggerganov/whisper.cpp/issues/168
 #include <windows.h>
-#endif
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
 
 typedef DWORD thread_ret_t;
 static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    (void) unused;
     HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
     if (handle == NULL)
     {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
 }
 
 static int pthread_join(pthread_t thread, void* unused) {
+    (void) unused;
     return (int) WaitForSingleObject(thread, INFINITE);
 }
 
@@ -117,6 +114,14 @@ typedef void* thread_ret_t;
     #define GGML_MEM_ALIGN 16
 #endif
 
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define GGML_ALIGNED_MALLOC(size)  _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr)     _aligned_free(ptr)
+#else
+#define GGML_ALIGNED_MALLOC(size)  aligned_alloc(GGML_MEM_ALIGN, size)
+#define GGML_ALIGNED_FREE(ptr)     free(ptr)
+#endif
+
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
@@ -231,12 +236,12 @@ static inline float fp32_from_bits(uint32_t w) {
 }
 
 static inline uint32_t fp32_to_bits(float f) {
-       union {
-               float as_value;
-               uint32_t as_bits;
-       } fp32;
-       fp32.as_value = f;
-       return fp32.as_bits;
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
 }
 
 static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -486,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 }
 #endif
 
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+    return
+        (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  +
+        (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  +
+        (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  +
+        (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  +
+        (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  +
+        (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+        (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+        (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+    return
+        (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+        (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+        (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+        (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline float vminvq_f32(float32x4_t v) {
+    return
+        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+    return vget_low_s8(vcombine_s8(a, b));
+}
+
+inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+    return vget_high_s8(vcombine_s8(a, b));
+}
+
+inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_low_u8(vcombine_u8(a, b));
+}
+
+inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_high_u8(vcombine_u8(a, b));
+}
+
+#endif
+#endif
+
 // method 5
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1213,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
 #define GGML_F32x4_ADD          vaddq_f32
 #define GGML_F32x4_MUL          vmulq_f32
-#if defined(__ARM_FEATURE_QRDMX)
-    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
-    #define GGML_F32x4_REDUCE_ONE(x) \
-    (vgetq_lane_f32(x, 0) +          \
-     vgetq_lane_f32(x, 1) +          \
-     vgetq_lane_f32(x, 2) +          \
-     vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
     for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1844,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
         const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
         const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
         const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
         const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
         const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
         const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
 
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
         const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int16x8_t
+        // dot product into int32x4_t
         int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
         int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
         p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
         p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-        // scalar
-#if defined(__ARM_FEATURE_QRDMX)
-        sum0 += x0->d * y0->d * vaddvq_s32(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s32(p_1);
+        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
 #else
-        sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
-        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
-#else
-           const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
@@ -1905,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-        // scalar
-#if defined(__ARM_FEATURE_QRDMX)
-        sum0 += x0->d * y0->d * vaddvq_s16(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
-        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
 #endif
     }
 
@@ -2155,18 +2205,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const uint8_t * restrict p0 = x[i].qs;
         const uint8_t * restrict p1 = y[i].qs;
 
+        int sumi = 0;
         for (int j = 0; j < QK/2; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
-            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
-            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+            const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
+            const int8_t i1 = (int8_t) (v0 >> 4)  - 8;
 
-            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
-            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+            const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
+            const int8_t i3 = (int8_t) (v1 >> 4)  - 8;
 
-            sumf += f0*f2 + f1*f3;
+            sumi += i0*i2 + i1*i3;
         }
+        sumf += d0 * d1 * sumi;
     }
 #endif
 
@@ -2258,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     float sum10 = 0.0f;
     float sum11 = 0.0f;
 
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict y0 = &y[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q4_1 * restrict y1 = &y[i + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
-        // and with 0xf
+        // 4-bit -> 8-bit
         const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
         const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
         const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
         const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
 
-        // dot product into uint16x8_t
+        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+        sum00 += x1->m*y1->m;
+        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
+
+        p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
+        p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
+
+        sum11 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum11 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
         const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
         const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
         const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
         const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
 
-        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
 
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
-        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
+        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+#endif
     }
 
     sumf = QK*sum00 + sum01 + sum10 + sum11;
@@ -2563,29 +2650,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 //
 
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
-    QK,
-    QK,
-    1,
-    1,
-    1,
-    1,
-    1,
+    [GGML_TYPE_F32]  = 1,
+    [GGML_TYPE_F16]  = 1,
+    [GGML_TYPE_Q4_0] = QK,
+    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_I8]   = 1,
+    [GGML_TYPE_I16]  = 1,
+    [GGML_TYPE_I32]  = 1,
 };
-
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
-    sizeof(block_q4_0),
-    sizeof(block_q4_1),
-    sizeof(int8_t ),
-    sizeof(int16_t),
-    sizeof(int32_t),
-    sizeof(ggml_fp16_t),
-    sizeof(float  ),
+    [GGML_TYPE_F32]  = sizeof(float),
+    [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
+    [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
+    [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_I8]   = sizeof(int8_t),
+    [GGML_TYPE_I16]  = sizeof(int16_t),
+    [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-
-// don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -2972,7 +3056,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     *ctx = (struct ggml_context) {
         /*.mem_size           =*/ params.mem_size,
-        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(params.mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
         /*.no_alloc           =*/ params.no_alloc,
         /*.n_objects          =*/ 0,
@@ -3007,7 +3091,7 @@ void ggml_free(struct ggml_context * ctx) {
                     __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
             if (ctx->mem_buffer_owned) {
-                free(ctx->mem_buffer);
+                GGML_ALIGNED_FREE(ctx->mem_buffer);
             }
 
             found = true;
@@ -6441,7 +6525,7 @@ static void ggml_compute_forward_mul_mat_f32(
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
-                                 x, ne10,
+                                 x, ne00,
                         0.0f,    d, ne01);
             }
         }
@@ -6613,7 +6697,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
-                                 x, ne10,
+                                 x, ne00,
                         0.0f,    d, ne01);
             }
         }
@@ -6826,7 +6910,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
-                                 x, ne10,
+                                 x, ne00,
                         0.0f,    d, ne01);
             }
         }
@@ -9279,7 +9363,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
-        /*.n_threads    =*/ 0,
+        /*.n_threads    =*/ GGML_DEFAULT_N_THREADS,
         /*.work_size    =*/ 0,
         /*.work         =*/ NULL,
         /*.nodes        =*/ { NULL },
@@ -9899,8 +9983,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 
     GGML_PRINT("=== GRAPH ===\n");
 
-    GGML_PRINT_DEBUG("n_threads       = %d\n",       cgraph->n_threads);
-    GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+    GGML_PRINT_DEBUG("n_threads       = %d\n",        cgraph->n_threads);
+    GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
 
     GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
     for (int i = 0; i < cgraph->n_nodes; i++) {
diff --git a/ggml.h b/ggml.h
index a5245a8ae6256c0bee449cff0d9112e24df152dc..c06c09e060db5ee127465e4b993e839300bb7be4 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -177,11 +177,12 @@ extern "C" {
 #include <stddef.h>
 #include <stdbool.h>
 
-#define GGML_MAX_DIMS     4
-#define GGML_MAX_NODES    4096
-#define GGML_MAX_PARAMS   16
-#define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_OPT      4
+#define GGML_MAX_DIMS          4
+#define GGML_MAX_NODES         4096
+#define GGML_MAX_PARAMS        16
+#define GGML_MAX_CONTEXTS      64
+#define GGML_MAX_OPT           4
+#define GGML_DEFAULT_N_THREADS 4
 
 #ifdef __ARM_NEON
 // we use the built-in 16-bit float type
@@ -198,13 +199,14 @@ struct ggml_object;
 struct ggml_context;
 
 enum ggml_type {
-    GGML_TYPE_Q4_0,
-    GGML_TYPE_Q4_1,
+    // explicitly numbered values are used in llama.cpp files
+    GGML_TYPE_F32  = 0,
+    GGML_TYPE_F16  = 1,
+    GGML_TYPE_Q4_0 = 2,
+    GGML_TYPE_Q4_1 = 3,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
-    GGML_TYPE_F16,
-    GGML_TYPE_F32,
     GGML_TYPE_COUNT,
 };