-// Defines CLOCK_MONOTONIC and asprintf on Linux
+// Defines CLOCK_MONOTONIC on Linux
#define _GNU_SOURCE
#include "ggml.h"
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
-#if defined _MSC_VER || defined(__MINGW32__)
+#if defined(_WIN32)
-#if !defined(__MINGW32__)
-#include <Windows.h>
-#else
-// ref: https://github.com/ggerganov/whisper.cpp/issues/168
#include <windows.h>
-#endif
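+
+// map the small pthread/atomic subset used in this file onto the Win32 API;
+// volatile LONG with the Interlocked* functions is assumed to give these
+// counters all the atomicity they need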
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ (void) unused;
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
{
}
static int pthread_join(pthread_t thread, void* unused) {
+ (void) unused;
return (int) WaitForSingleObject(thread, INFINITE);
}
#define GGML_MEM_ALIGN 16
#endif
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
+#else
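+// note: C11 aligned_alloc requires the size to be a multiple of the
+// alignment; the sizes passed in this file are assumed to satisfy that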
+#define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
+#define GGML_ALIGNED_FREE(ptr) free(ptr)
+#endif
+
#define UNUSED(x) (void)(x)
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
}
static inline uint32_t fp32_to_bits(float f) {
- union {
- float as_value;
- uint32_t as_bits;
- } fp32;
- fp32.as_value = f;
- return fp32.as_bits;
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32;
+ fp32.as_value = f;
+ return fp32.as_bits;
}
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
}
#endif
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
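+// 32-bit ARM (armv7) compatibility: the horizontal-add, min/max and zip
+// intrinsics used below exist only on AArch64, so provide scalar fallbacks
+// under the same names (vaddvq_u8 returns uint16_t so the full sum does not wrap)
+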
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+ return
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+ return
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vminvq_f32(float32x4_t v) {
+    return
+        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+// note: vget_low/vget_high would just return the inputs unchanged;
+// vzip_{s,u}8 performs the actual interleave that AArch64 vzip1/vzip2 do
+inline static int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+    return vzip_s8(a, b).val[0];
+}
+
+inline static int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+    return vzip_s8(a, b).val[1];
+}
+
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    return vzip_u8(a, b).val[0];
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    return vzip_u8(a, b).val[1];
+}
+
+#endif
+#endif
+
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e. QK 4-bit signed integer values)
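+//
+// for reference, a block under this scheme (field names as used by the
+// sizeof()/member accesses below) is laid out roughly as:
+//
+//   typedef struct {
+//       float   d;          // delta (scale), one per block
+//       uint8_t qs[QK / 2]; // two 4-bit quants per byte
+//   } block_q4_0;
+//
+// an element is reconstructed as x = d*(q - 8), with the stored q in [0, 15]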
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32x4_ADD vaddq_f32
#define GGML_F32x4_MUL vmulq_f32
-#if defined(__ARM_FEATURE_QRDMX)
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
- #define GGML_F32x4_REDUCE_ONE(x) \
- (vgetq_lane_f32(x, 0) + \
- vgetq_lane_f32(x, 1) + \
- vgetq_lane_f32(x, 2) + \
- vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
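+// (vaddvq_f32 is native on AArch64 and supplied by the scalar shim above on
+// 32-bit ARM, so no feature guard is needed)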
#define GGML_F32x4_REDUCE(res, x) \
{ \
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
#if defined(__ARM_FEATURE_DOTPROD)
- // dot product into int16x8_t
+ // dot product into int32x4_t
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
- // scalar
-#if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
#else
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
- // scalar
-#if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
#endif
}
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
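+    // accumulate the dot product in integers and apply the two block scales
+    // once at the end; this is exact and keeps float math out of the inner loop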
+ int sumi = 0;
for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j];
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
+ const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
+ const int8_t i1 = (int8_t) (v0 >> 4) - 8;
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
+ const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
+ const int8_t i3 = (int8_t) (v1 >> 4) - 8;
- sumf += f0*f2 + f1*f3;
+ sumi += i0*i2 + i1*i3;
}
+ sumf += d0 * d1 * sumi;
}
#endif
float sum10 = 0.0f;
float sum11 = 0.0f;
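+    // note: two blocks are processed per iteration; nb is assumed to be even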
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict y0 = &y[i + 0];
+ const block_q4_1 * restrict x1 = &x[i + 1];
+ const block_q4_1 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
- // and with 0xf
+ // 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
- // dot product into uint16x8_t
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+ sum00 += x0->m*y0->m;
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+ sum00 += x1->m*y1->m;
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+    // dot product into uint32x4_t (the Q4_1 nibbles are unsigned, unlike Q4_0)
+    uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
+    uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
+
+    p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
+    p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
+
+    sum11 += x0->d*y0->d*vaddvq_u32(p_0);
+    sum11 += x1->d*y1->d*vaddvq_u32(p_1);
+#else
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
- sum00 += x0->m*y0->m;
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+#endif
}
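+    // with x_i = x->d*qx_i + x->m and y_i = y->d*qy_i + y->m, each block's
+    // dot product expands into the four partial sums accumulated above:
+    //
+    //   sum_i x_i*y_i = QK*x->m*y->m                  -> QK*sum00
+    //                 + y->m*x->d * sum_i qx_i        -> sum01
+    //                 + x->m*y->d * sum_i qy_i        -> sum10
+    //                 + x->d*y->d * sum_i qx_i*qy_i   -> sum11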
sumf = QK*sum00 + sum01 + sum10 + sum11;
//
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
- QK,
- QK,
- 1,
- 1,
- 1,
- 1,
- 1,
+ [GGML_TYPE_F32] = 1,
+ [GGML_TYPE_F16] = 1,
+ [GGML_TYPE_Q4_0] = QK,
+ [GGML_TYPE_Q4_1] = QK,
+ [GGML_TYPE_I8] = 1,
+ [GGML_TYPE_I16] = 1,
+ [GGML_TYPE_I32] = 1,
};
-
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
- sizeof(block_q4_0),
- sizeof(block_q4_1),
- sizeof(int8_t ),
- sizeof(int16_t),
- sizeof(int32_t),
- sizeof(ggml_fp16_t),
- sizeof(float ),
+ [GGML_TYPE_F32] = sizeof(float),
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+ [GGML_TYPE_I8] = sizeof(int8_t),
+ [GGML_TYPE_I16] = sizeof(int16_t),
+ [GGML_TYPE_I32] = sizeof(int32_t),
};
-
-// don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
*ctx = (struct ggml_context) {
/*.mem_size =*/ params.mem_size,
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(params.mem_size),
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
/*.no_alloc =*/ params.no_alloc,
/*.n_objects =*/ 0,
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
if (ctx->mem_buffer_owned) {
- free(ctx->mem_buffer);
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
}
found = true;
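+                        // ldb for x is x's own row stride ne00; ggml_mul_mat
+                        // requires ne10 == ne00, so the value is unchanged,
+                        // but ne00 is the correct stride to pass for x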
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
- x, ne10,
+ x, ne00,
0.0f, d, ne01);
}
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
- x, ne10,
+ x, ne00,
0.0f, d, ne01);
}
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
- x, ne10,
+ x, ne00,
0.0f, d, ne01);
}
}
struct ggml_cgraph result = {
/*.n_nodes =*/ 0,
/*.n_leafs =*/ 0,
- /*.n_threads =*/ 0,
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
/*.work_size =*/ 0,
/*.work =*/ NULL,
/*.nodes =*/ { NULL },
GGML_PRINT("=== GRAPH ===\n");
- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {