]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix I8MM Q4_1 scaling factor conversion (llama/10562)
authorGeorgi Gerganov <redacted>
Fri, 29 Nov 2024 14:25:39 +0000 (16:25 +0200)
committerGeorgi Gerganov <redacted>
Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
ggml-ci

ggml/src/ggml-cpu/ggml-cpu-quants.c
ggml/src/ggml-cpu/ggml-cpu.c

index 11e8df253d5caca18bd4fd1815c6af16ecc5d601..634c5fa1162c3c4b6e030a713a953679b4f0b273 100644 (file)
@@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
-            float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                    GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                    GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
-
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);
 
             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -1811,7 +1812,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
 
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
         }
 
         float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
@@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const block_q8_1 * restrict b_y0 = &vy0[i];
             const block_q8_1 * restrict b_y1 = &vy1[i];
 
-            float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                    GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-                                    GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+            float32_t summs_t[4] = {
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+            };
             summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
 
             const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
             // mmla into int32x4_t
-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-                                   GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-                                   GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-                                   GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);
 
             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2389,15 +2394,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
         }
 
-        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
         sumv2 = vaddq_f32(sumv2, summs0);
 
         vst1_f32(s,      vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
@@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                   GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                   GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);
 
             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3393,13 +3402,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
 
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                       l1, r1)), l2, r2)), l3, r3))), scale);
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
         }
-        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
 
-        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s,      vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
index 1c88e5d81ab6ccc97208edf3d8d1c56e2db58311..e0cefc20b4d40c7fc79ebdf5ca317a14e644a277 100644 (file)
@@ -7641,8 +7641,8 @@ UseGgmlGemm2:;
         // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
         int64_t num_rows_per_vec_dot = vec_dot_num_rows;
 
-        // TODO: currently the mmla kernels support only even numbered rows/cols.
-        // this check can be removed once they are extended to support odd numbered rows/cols too
+        // these checks are needed to avoid crossing dim1 boundaries
+        // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }