]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
backend cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels (llama/9921)
authorCharles Xu <redacted>
Fri, 15 Nov 2024 00:28:50 +0000 (01:28 +0100)
committerGeorgi Gerganov <redacted>
Wed, 20 Nov 2024 19:00:08 +0000 (21:00 +0200)
* backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels

---------

Co-authored-by: Diego Devesa <redacted>
ggml/CMakeLists.txt
ggml/include/ggml-cpu.h
ggml/src/ggml-cpu/CMakeLists.txt
ggml/src/ggml-cpu/ggml-cpu-aarch64.c
ggml/src/ggml-cpu/ggml-cpu-aarch64.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp

index 154f82fe7ddc371932b6bcd32eb322d718a998a4..c02bfb59d85698c898a32b2f0e74baf43516a6be 100644 (file)
@@ -92,6 +92,7 @@ else()
 endif()
 
 option(GGML_CPU_HBM     "ggml: use memkind for CPU HBM" OFF)
+option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 
 option(GGML_AVX         "ggml: enable AVX"              ${INS_ENB})
 option(GGML_AVX2        "ggml: enable AVX2"             ${INS_ENB})
index 4da62cb2b63f3dacb363b8a693d96065f0ef4c55..7571ef9798364854d7fbd5b38ff6ebe6af0dfd29 100644 (file)
@@ -169,6 +169,9 @@ extern "C" {
     GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
+    GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
+    GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
+
 #ifdef __cplusplus
 }
 #endif
index 4d96f425e3c6e587a69811b2a2b4c54a102073f0..8b0d60d4ec7b9ae7d9617a026a30123a65130987 100644 (file)
@@ -236,6 +236,11 @@ else()
     message(STATUS "Unknown architecture")
 endif()
 
+if (GGML_CPU_AARCH64)
+    message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
+    add_compile_definitions(GGML_USE_CPU_AARCH64)
+endif()
+
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
 
index 0ad9fe40a3e0ac58737d34ea900ba781a0f2fc47..b753ba767c15ad5415029b9c07a876272ee5d000 100644 (file)
@@ -3385,3 +3385,147 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
 }
+
+// FIXME: this code is duplicated from ggml-aarch64.c
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 4; i++) {
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 8);
+
+    block_q4_0x8 * dst = (block_q4_0x8*)t->data;
+    const block_q4_0 * src = (const block_q4_0*) data;
+    block_q4_0 dst_tmp[8];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i  = 0; i < nrows_interleaved; i++ ) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
+    if (cur->type == repack_type) {
+        memcpy(cur->data, data, data_size);
+        return;
+    }
+
+    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
+
+    switch (repack_type) {
+        case GGML_TYPE_Q4_0_8_8:
+            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_8:
+            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_4:
+            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
+    }
+}
+
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
+        if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            return GGML_TYPE_Q4_0_8_8;
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            return GGML_TYPE_Q4_0_4_8;
+        }
+        if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_Q4_0_4_4;
+        }
+    }
+
+    return cur->type;
+}
index 203802f07320cab7fe584983a40e28ef39e5e752..53b30c1dd2dfea0b2f34a2039d025f0fb608d3c0 100644 (file)
@@ -21,6 +21,9 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
+void           ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
+
 #ifdef __cplusplus
 }
 #endif
index 4c45146a1f0f3d2924154f1c2f880e50244efcea..30b1bf895720e8874127d6968253f6e8effb86a2 100644 (file)
@@ -7330,6 +7330,7 @@ static void ggml_compute_forward_group_norm(
 static void ggml_compute_forward_mul_mat_one_chunk(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst,
+    const enum ggml_type type,
     const int64_t num_rows_per_vec_dot,
     const int64_t ir0_start,
     const int64_t ir0_end,
@@ -7341,8 +7342,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
-
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
@@ -7430,7 +7429,11 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const enum ggml_type type = src0->type;
+    enum ggml_type type = src0->type;
+
+    if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        type = (enum ggml_type)(intptr_t)src0->extra;
+    }
 
     enum ggml_type           const vec_dot_type         = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t        const from_float           = type_traits_cpu[vec_dot_type].from_float;
@@ -7469,15 +7472,15 @@ static void ggml_compute_forward_mul_mat(
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      src1->type,
                                      dst->type))
                     goto UseGgmlGemm1;
@@ -7530,15 +7533,15 @@ UseGgmlGemm1:;
 
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      vec_dot_type,
                                      dst->type))
                     goto UseGgmlGemm2;
@@ -7623,7 +7626,7 @@ UseGgmlGemm2:;
         const int64_t ir1_start = dr1 * ith1;
         const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
             break;
index c7216117bc80502caadca726b74f3bcf20c0686a..573b7c5b9b375bdabd4c66ef35c391225d011f38 100644 (file)
@@ -1,6 +1,7 @@
 #include "ggml-backend.h"
 #include "ggml-backend-impl.h"
 #include "ggml-cpu.h"
+#include "ggml-cpu-aarch64.h"
 #include "ggml-impl.h"
 #include <cctype>
 #include <string>
@@ -69,15 +70,84 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
 }
 #endif
 
+// buffer type AARCH64
+
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
+
+    ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    auto * buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == NULL) {
+        return NULL;
+    }
+
+    buffer->buft = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ NULL,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_aarch64;
+}
+
+bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
+    return buft == ggml_backend_cpu_aarch64_buffer_type();
+}
+
 static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
-    static ggml_backend_buffer_type_t bufts[] = {
+    static std::vector<ggml_backend_buffer_type_t> bufts = []() {
+        std::vector<ggml_backend_buffer_type_t> bufts;
+
 #ifdef GGML_USE_CPU_HBM
-        ggml_backend_cpu_hbm_buffer_type(),
+        bufts.push_back(ggml_backend_cpu_hbm_buffer_type());
+#endif
+
+#ifdef GGML_USE_CPU_AARCH64
+        bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
 #endif
-        NULL
-    };
 
-    return bufts;
+        bufts.push_back(NULL);
+
+        return bufts;
+    }();
+
+    return bufts.data();
 
     GGML_UNUSED(device);
 }
@@ -383,6 +453,21 @@ static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_b
 }
 
 static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT || src0->type != GGML_TYPE_Q4_0 || ggml_aarch64_get_optimal_repack_type(src0) == GGML_TYPE_Q4_0) {
+            return false;
+        }
+    }
+
+    for (int i = 1; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+
     switch (op->op) {
         case GGML_OP_CPY:
             return
@@ -391,13 +476,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_S   &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32;// FIXME || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
+            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
         case GGML_OP_ROPE_BACK:
             return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
         case GGML_OP_IM2COL_BACK:
-            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
         default:
             return true;
     }
@@ -406,7 +491,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 }
 
 static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
+    return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
 
     GGML_UNUSED(dev);
 }
@@ -566,6 +651,9 @@ static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    // init CPU feature detection
+    ggml_cpu_init();
+
     static struct ggml_backend_reg ggml_backend_cpu_reg = {
         /* .iface   = */ ggml_backend_cpu_reg_i,
         /* .context = */ NULL,