]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
tests : wip quantized matrix multiplication method 2
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 07:36:32 +0000 (09:36 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 07:37:23 +0000 (09:37 +0200)
CMakeLists.txt
tests/test-mul-mat2.c

index d2f95ccef73df36c787f4226bc728d3c42e3694e..d88c5b101348b8aaef6cd87229474934488268dc 100644 (file)
@@ -47,6 +47,7 @@ endif()
 
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
+#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
 
 # dependencies
 
index c43479c12edf68dbed5fd1798efa3c3c13942d64..94a5e2768d3f05dc6a3ec1b42886628e2006ef34 100644 (file)
@@ -42,7 +42,7 @@ uint64_t get_time_us() {
 // naive implementation
 //
 
-void mul_mat_vec_f32_0(
+void mul_mat_vec_f32_naive(
     const float * restrict src0, // M x K
     const float * restrict src1, // N x K (transposed)
     float * dst,
@@ -58,7 +58,11 @@ void mul_mat_vec_f32_0(
     }
 }
 
-void quantize(const float * src, void * dst, int n, int k) {
+//
+// method 1
+//
+
+void quantize_1(const float * src, void * dst, int n, int k) {
     char * p0 = dst;
 
     gq_t pp[QB];
@@ -128,7 +132,7 @@ void quantize(const float * src, void * dst, int n, int k) {
     }
 }
 
-void mul_mat_vec_gq_0(
+void mul_mat_vec_gq_1(
     const void * src0,
     const void * src1,
          float * dst,
@@ -138,6 +142,12 @@ void mul_mat_vec_gq_0(
     const char * restrict p0 = src0;
     const char * restrict p1 = src1;
 
+    float s0[QB + 1];
+    float s1[QB + 1];
+
+    gq_t m0[QB + 1];
+    gq_t m1[QB + 1];
+
     for (int ir0 = 0; ir0 < m; ir0++) {
         for (int ir1 = 0; ir1 < n; ir1++) {
             float sumf = 0.0;
@@ -159,9 +169,6 @@ void mul_mat_vec_gq_0(
 #if 1
                 // >>> General case for any QB
 
-                float s0[QB + 1];
-                float s1[QB + 1];
-
                 s0[0] = min0;
                 s1[0] = min1;
 
@@ -170,8 +177,146 @@ void mul_mat_vec_gq_0(
                     s1[b + 1] = d1*(1 << b);
                 }
 
-                gq_t m0[QB + 1];
-                gq_t m1[QB + 1];
+                m0[0] = -1LL;
+                m1[0] = -1LL;
+
+                for (int s = 0; s < QK/gq_t_bits; ++s) {
+                    for (int b = 0; b < QB; b++) {
+                        memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
+                        memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
+                    }
+
+                    for (int q0 = 0; q0 < QB + 1; q0++) {
+                        for (int q1 = 0; q1 < QB + 1; q1++) {
+                            sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
+                        }
+                    }
+                }
+#else
+#endif
+            }
+
+            dst[ir0*n + ir1] = sumf;
+        }
+    }
+}
+
+//
+// method 2
+//
+
+void quantize_2(const float * src, void * dst, int n, int k) {
+    char * p0 = dst;
+
+    for (int j = 0; j < n; j++) {
+        for (int i = 0; i < k/QK; i++) {
+            float min = FLT_MAX;
+            float max = -FLT_MAX;
+
+            // find min/max
+#ifdef __ARM_NEON
+            {
+                float32x4_t minv = vdupq_n_f32(FLT_MAX);
+                float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
+
+                for (int l = 0; l < QK; l += 4) {
+                    float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
+                    minv = vminq_f32(minv, v);
+                    maxv = vmaxq_f32(maxv, v);
+                }
+
+                float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
+                float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
+
+                min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
+                max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
+
+                //printf("SIMD min/max: %f %f\n", min, max);
+            }
+#else
+            {
+                for (int l = 0; l < QK; l++) {
+                    const float v = src[j*k + i*QK + l];
+                    if (v < min) min = v;
+                    if (v > max) max = v;
+                }
+
+                //printf("NORM min/max: %f %f\n", min, max);
+            }
+#endif
+
+            const float d = (max - min) / ((1 << QB) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
+            memcpy(p0, &d,   sizeof(float)); p0 += sizeof(float);
+
+            //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
+
+            for (int s = 0; s < QK/gq_t_bits; ++s) {
+                gq_t pp[QB] = {0};
+
+                for (int l = 0; l < gq_t_bits; l++) {
+                    const   float v = src[j*k + i*QK + s*gq_t_bits + l];
+                    const uint8_t q = (v - min)*id;
+
+                    for (int b = 0; b < QB; b++) {
+                        pp[b] |= q & (1 << b) ? (1LL << l) : 0;
+                    }
+                }
+
+                for (int b = 0; b < QB; b++) {
+                    memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
+                }
+            }
+        }
+    }
+}
+
+void mul_mat_vec_gq_2(
+    const void * src0,
+    const void * src1,
+         float * dst,
+    int m, int n, int k) {
+    const int kp = k & ~(gq_t_bits - 1);
+
+    const char * restrict p0 = src0;
+    const char * restrict p1 = src1;
+
+    float s0[QB + 1];
+    float s1[QB + 1];
+
+    gq_t m0[QB + 1];
+    gq_t m1[QB + 1];
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            float sumf = 0.0;
+
+            const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
+            const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
+
+            for (int i = 0; i < kp/QK; i++) {
+                float min0, d0;
+                memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
+                memcpy(&d0,   pp0, sizeof(float)); pp0 += sizeof(float);
+
+                float min1, d1;
+                memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
+                memcpy(&d1,   pp1, sizeof(float)); pp1 += sizeof(float);
+
+                //printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
+
+#if 1
+                // >>> General case for any QB
+
+                s0[0] = min0;
+                s1[0] = min1;
+
+                for (int b = 0; b < QB; b++) {
+                    s0[b + 1] = d0*(1 << b);
+                    s1[b + 1] = d1*(1 << b);
+                }
 
                 m0[0] = -1LL;
                 m1[0] = -1LL;
@@ -198,6 +343,8 @@ void mul_mat_vec_gq_0(
 }
 
 int main(int argc, const char ** argv) {
+    assert(sizeof(gq_t)*8 == gq_t_bits);
+
     float * src0 = (float *)malloc(sizeof(float)*M*K);
     float * src1 = (float *)malloc(sizeof(float)*N*K);
     float * dst  = (float *)malloc(sizeof(float)*M*N);
@@ -219,20 +366,27 @@ int main(int argc, const char ** argv) {
 
     printf("compression: %f\n", (float)sizegq/sizef16);
 
+    int method = 0;
+    if (argc > 1) {
+        method = atoi(argv[1]);
+    }
+
     // convert fp32 -> gq
     {
         const uint64_t t_start = get_time_us();
 
-        quantize(src0, src0_gq, M, K);
-        quantize(src1, src1_gq, N, K);
+        if (method == 1) {
+            quantize_1(src0, src0_gq, M, K);
+            quantize_1(src1, src1_gq, N, K);
+        }
 
-        const uint64_t t_end = get_time_us();
-        printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
-    }
+        if (method == 2) {
+            quantize_2(src0, src0_gq, M, K);
+            quantize_2(src1, src1_gq, N, K);
+        }
 
-    int method = 0;
-    if (argc > 1) {
-        method = atoi(argv[1]);
+        const uint64_t t_end = get_time_us();
+        printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
     }
 
     const int nIter = 1;
@@ -244,11 +398,15 @@ int main(int argc, const char ** argv) {
     double sum = 0.0f;
     for (int i = 0; i < nIter; i++) {
         if (method == 0) {
-            mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
+            mul_mat_vec_f32_naive(src0, src1, dst, M, N, K);
         }
 
         if (method == 1) {
-            mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K);
+            mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
+        }
+
+        if (method == 2) {
+            mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
         }
     }