]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
tests : experiments with n-bit quantized matrix multiplication
authorGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 19:05:41 +0000 (21:05 +0200)
committerGeorgi Gerganov <redacted>
Thu, 5 Jan 2023 19:05:41 +0000 (21:05 +0200)
.gitignore
tests/CMakeLists.txt
tests/test-mul-mat1.c
tests/test-mul-mat2.c [new file with mode: 0644]

index 9e331d7ce2715fed6ba49df25dc0c75fa781e08b..bbfa8998e65202840f509ac34134306827f01ef6 100644 (file)
@@ -9,3 +9,4 @@ compile_commands.json
 .DS_Store
 
 src/arm_neon.h
+tests/arm_neon.h
index db66331d17bf9b76777d80f18e8c6c64e09e05ad..a20cf2859441a1205034add1bb142c4e33e2bbaf 100644 (file)
@@ -65,6 +65,14 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" AND NOT GGML_NO_ACCELERATE)
     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 endif()
 
+#
+# test-mul-mat2
+
+set(TEST_TARGET test-mul-mat2)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
 #
 # test0
 
index fb85b420bcdd0ddb6541be4b81843992bd5d7a1c..9a8441623c95a307314d1bdf2c4afaf15075b33c 100644 (file)
@@ -16,6 +16,12 @@ const int M = 1280;
 const int N = 1500;
 const int K = 1280;
 
+uint64_t get_time_us() {
+    struct timeval tv;
+    gettimeofday(&tv, NULL);
+    return tv.tv_sec * 1000000 + tv.tv_usec;
+}
+
 //
 // naive implementation
 //
@@ -206,12 +212,6 @@ void mul_mat_vec_f8_0(
     }
 }
 
-uint64_t get_time_us() {
-    struct timeval tv;
-    gettimeofday(&tv, NULL);
-    return tv.tv_sec * 1000000 + tv.tv_usec;
-}
-
 int main(int argc, const char ** argv) {
     float * src0 = (float *)malloc(sizeof(float)*M*K);
     float * src1 = (float *)malloc(sizeof(float)*N*K);
diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c
new file mode 100644 (file)
index 0000000..d025782
--- /dev/null
@@ -0,0 +1,272 @@
+// quantized matrix multiplication
+
+#include <float.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+#include <math.h>
+
+#include <sys/time.h>
+
+#ifdef __ARM_NEON
+#include "arm_neon.h"
+#endif
+
+#ifndef MIN
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+const int M = 1280;
+const int N = 1536;
+const int K = 1280;
+
+const int QK = 64;
+const int QB = 7;
+
+#define gq_t uint64_t
+#define gq_t_bits 64
+
+uint64_t get_time_us() {
+    struct timeval tv;
+    gettimeofday(&tv, NULL);
+    return tv.tv_sec * 1000000 + tv.tv_usec;
+}
+
+//
+// naive implementation
+//
+
+void mul_mat_vec_f32_0(
+    const float * restrict src0, // M x K
+    const float * restrict src1, // N x K (transposed)
+    float * dst,
+    int m, int n, int k) {
+    for (int i = 0; i < m; i++) {
+        for (int j = 0; j < n; j++) {
+            float sum = 0;
+            for (int l = 0; l < k; l++) {
+                sum += src0[i*k + l] * src1[j*k + l];
+            }
+            dst[i*n + j] = sum;
+        }
+    }
+}
+
+void quantize(const float * src, void * dst, int n, int k) {
+    char * p0 = dst;
+
+    for (int j = 0; j < n; j++) {
+        for (int i = 0; i < k/QK; i++) {
+            float min = FLT_MAX;
+            float max = -FLT_MAX;
+
+            // find min/max
+#ifdef __ARM_NEON
+            {
+                float32x4_t minv = vdupq_n_f32(FLT_MAX);
+                float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
+
+                for (int l = 0; l < QK; l += 4) {
+                    float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
+                    minv = vminq_f32(minv, v);
+                    maxv = vmaxq_f32(maxv, v);
+                }
+
+                float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
+                float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
+
+                min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
+                max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
+
+                //printf("SIMD min/max: %f %f\n", min, max);
+            }
+#else
+            {
+                for (int l = 0; l < QK; l++) {
+                    const float v = src[j*k + i*QK + l];
+                    if (v < min) min = v;
+                    if (v > max) max = v;
+                }
+
+                //printf("NORM min/max: %f %f\n", min, max);
+            }
+#endif
+
+            const float d = (max - min) / ((1 << QB) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
+            memcpy(p0, &d,   sizeof(float)); p0 += sizeof(float);
+
+            //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
+
+            for (int s = 0; s < QK/gq_t_bits; ++s) {
+                gq_t pp[QB] = {0};
+
+                for (int l = 0; l < gq_t_bits; l++) {
+                    const   float v = src[j*k + i*QK + s*gq_t_bits + l];
+                    const uint8_t q = (v - min)*id;
+
+                    for (int b = 0; b < QB; b++) {
+                        pp[b] |= q & (1 << b) ? (1LL << l) : 0;
+                    }
+                }
+
+                for (int b = 0; b < QB; b++) {
+                    memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
+                }
+            }
+        }
+    }
+}
+
+void mul_mat_vec_gq_0(
+    const void * src0,
+    const void * src1,
+         float * dst,
+    int m, int n, int k) {
+    const int kp = k & ~(gq_t_bits - 1);
+
+    const char * restrict p0 = src0;
+    const char * restrict p1 = src1;
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            float sumf = 0.0;
+
+            const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
+            const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
+
+            for (int i = 0; i < kp/QK; i++) {
+                float min0, d0;
+                memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
+                memcpy(&d0,   pp0, sizeof(float)); pp0 += sizeof(float);
+
+                float min1, d1;
+                memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
+                memcpy(&d1,   pp1, sizeof(float)); pp1 += sizeof(float);
+
+                //printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
+
+#if 1
+                // >>> General case for any QB
+
+                float s0[QB + 1];
+                float s1[QB + 1];
+
+                s0[0] = min0;
+                s1[0] = min1;
+
+                for (int b = 0; b < QB; b++) {
+                    s0[b + 1] = d0*(1 << b);
+                    s1[b + 1] = d1*(1 << b);
+                }
+
+                gq_t m0[QB + 1];
+                gq_t m1[QB + 1];
+
+                m0[0] = -1LL;
+                m1[0] = -1LL;
+
+                for (int s = 0; s < QK/gq_t_bits; ++s) {
+                    for (int b = 0; b < QB; b++) {
+                        memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
+                        memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
+                    }
+
+                    for (int q0 = 0; q0 < QB + 1; q0++) {
+                        for (int q1 = 0; q1 < QB + 1; q1++) {
+                            sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
+                        }
+                    }
+                }
+#else
+#endif
+            }
+
+            dst[ir0*n + ir1] = sumf;
+        }
+    }
+}
+
+int main(int argc, const char ** argv) {
+    float * src0 = (float *)malloc(sizeof(float)*M*K);
+    float * src1 = (float *)malloc(sizeof(float)*N*K);
+    float * dst  = (float *)malloc(sizeof(float)*M*N);
+
+    for (int i = 0; i < M*K; i++) {
+        src0[i] = rand() / (float)RAND_MAX;
+    }
+
+    for (int i = 0; i < N*K; i++) {
+        src1[i] = rand() / (float)RAND_MAX;
+    }
+
+    void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M);
+    void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N);
+
+    const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K;
+    const size_t sizegq  = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M +
+                           (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N;
+
+    printf("compression: %f\n", (float)sizegq/sizef16);
+
+    // convert fp32 -> gq
+    {
+        const uint64_t t_start = get_time_us();
+
+        quantize(src0, src0_gq, M, K);
+        quantize(src1, src1_gq, N, K);
+
+        const uint64_t t_end = get_time_us();
+        printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
+    }
+
+    int method = 0;
+    if (argc > 1) {
+        method = atoi(argv[1]);
+    }
+
+    const int nIter = 1;
+
+    const clock_t start = clock();
+    const uint64_t start_us = get_time_us();
+
+    double iM = 1.0/M;
+    double sum = 0.0f;
+    for (int i = 0; i < nIter; i++) {
+        if (method == 0) {
+            mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
+        }
+
+        if (method == 1) {
+            mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K);
+        }
+    }
+
+    for (int i = 0; i < N; i++) {
+        sum += dst[i]*iM;
+    }
+
+    {
+        const clock_t end = clock();
+        const uint64_t end_us = get_time_us();
+        printf("%s: elapsed ticks: %ld\n",  __func__, end - start);
+        printf("%s: elapsed us:    %llu / %f ms\n",  __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter);
+    }
+
+    printf("%f\n", sum);
+
+    free(src0);
+    free(src1);
+    free(dst);
+
+    free(src0_gq);
+    free(src1_gq);
+
+    return 0;
+}