git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cpu: aarch64: q5_K repack gemm and gemv (and generic) implementations (i8mm)...
author    Alberto Cabrera Pérez <redacted>
Fri, 23 Jan 2026 07:55:08 +0000 (07:55 +0000)
committer Georgi Gerganov <redacted>
Fri, 30 Jan 2026 13:56:40 +0000 (15:56 +0200)
* Boilerplate for q5_Kx8 REPACK on ARM and fallback

Signed-off-by: Alberto Cabrera <redacted>
* Implements make_block_q5_Kx8 by extending make_block_q4_Kx8

Signed-off-by: Alberto Cabrera <redacted>
* q5_K repack gemm and gemv generics

* Gemm and Gemv ARM implementations (i8mm)

* Improved qh manipulation, following the non-repack vec_dot implementation

* Full unroll

* Apply Q5_K Gemv vand and vshl optimizations to gemm. Improve comments.

Signed-off-by: Alberto Cabrera <redacted>
* Fix wrong fallback definitions of Q5_K

Signed-off-by: Alberto Cabrera <redacted>
* Fixed comments. Reverted unnecessary formatting

Signed-off-by: Alberto Cabrera <redacted>
* Fixed typo in generic definitions

* Switch AND + shift to a shift-insert (vsli). Better op interleaving (see the sketch after this message).

* Vectorize + unroll the block scales

* Apply gemm optimizations to gemv

* Improve bias calculation

---------

Signed-off-by: Alberto Cabrera <redacted>
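
For context, the "shift insert" and qh bullets above refer to how each 5-bit Q5_K weight is rebuilt from a 4-bit low nibble in qs plus one high bit in qh. Below is a minimal sketch of that reconstruction, assuming AArch64 NEON; the helper names are illustrative, but the intrinsic sequence mirrors the kernels in the diff.

#include <arm_neon.h>

// Low nibbles: a single VSLI (shift left and insert) places the high bit at bit 4
// on top of the low nibble, replacing the earlier AND + shift + OR sequence.
static inline int8x16_t q5_from_low_nibbles(uint8x16_t qs, uint8x16_t qh) {
    const uint8x16_t m4b  = vdupq_n_u8(0x0f);
    const uint8x16_t mone = vdupq_n_u8(1);
    const uint8x16_t hbit = vandq_u8(qh, mone);                        // bit 0 of qh
    return vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs, m4b), hbit, 4));
}

// The equivalent AND + shift + OR form the commit replaces for the low nibbles.
static inline int8x16_t q5_from_low_nibbles_and_shift(uint8x16_t qs, uint8x16_t qh) {
    const uint8x16_t m4b  = vdupq_n_u8(0x0f);
    const uint8x16_t mone = vdupq_n_u8(1);
    return vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs, m4b),
                                        vshlq_n_u8(vandq_u8(qh, mone), 4)));
}

// High nibbles: bit 1 of qh is moved to bit 4 and OR'd over the shifted-down upper
// nibble; the kernels then shift qh right by 2 so the next subblock's bits are back
// in positions 0 and 1.
static inline int8x16_t q5_from_high_nibbles(uint8x16_t qs, uint8x16_t qh) {
    const uint8x16_t mtwo = vdupq_n_u8(2);
    const uint8x16_t hbit = vshlq_n_u8(vandq_u8(qh, mtwo), 3);         // bit 1 -> bit 4
    return vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs, 4), hbit));
}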
ggml/src/ggml-cpu/arch-fallback.h
ggml/src/ggml-cpu/arch/arm/repack.cpp
ggml/src/ggml-cpu/repack.cpp
ggml/src/ggml-cpu/repack.h

index 3f8946ac701c476be106c3aa94fd4d243a58bf44..0a85a4cff304c7ab2bd45f9bafc3fab103e7f344 100644 (file)
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
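
The repeated blocks above come from the per-architecture sections of arch-fallback.h. The pattern they implement is plain aliasing, sketched below under an assumed placeholder condition (the real header keys off the target architecture): on targets without a hand-written q5_K kernel, the *_generic symbol is renamed to the public entry point, so the portable implementation in repack.cpp is compiled under the name callers expect.

// Illustrative only; HAVE_OPTIMIZED_Q5_K is a made-up placeholder condition.
#ifndef HAVE_OPTIMIZED_Q5_K
#    define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#    define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#endif

// In repack.cpp the portable kernel is defined via the *_generic name, so after
// preprocessing it becomes ggml_gemv_q5_K_8x8_q8_K on fallback targets:
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * s, size_t bs,
                                     const void * vx, const void * vy, int nr, int nc);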
index b61220a189a3448bd3649e8c6c3491ff313814a7..883d862901b19cd7a9578aaeba0e6f76d4dad785 100644 (file)
@@ -25,9 +25,8 @@
 #define UNUSED GGML_UNUSED
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t *     out_mins,
-                                             int8_t *        out_scales) {
+// Helper for decoding scales and mins of Q4_K and Q5_K block formats
+static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
     constexpr uint32_t kmask1 = 0x3f3f3f3f;
     constexpr uint32_t kmask2 = 0x0f0f0f0f;
     constexpr uint32_t kmask3 = 0x03030303;
@@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
@@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
@@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    const uint8x16_t mone      = vdupq_n_u8(1);
+    const uint8x16_t mtwo      = vdupq_n_u8(2);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q5_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q5_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0   = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_1   = vmulq_f32(q5_dmin_1, q8_d);
+
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // bsums holds 16 partial sums (one per 16 quants); a pairwise add leaves the 8 per-subblock sums of the block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            // Load qh once per block and shift after each subblock
+            const uint8_t * qh_base = q5_ptr[b].qh;
+            uint8x16_t      qh[col_pairs][4];
+            for (int cp = 0; cp < col_pairs; cp++) {
+                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+            }
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
+                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t      q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q5s column pair loop unrolled
+                {
+                    // Cols 01
+                    uint8x16_t qs_0 = vld1q_u8(qs_base);
+                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
+                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
+                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
+
+                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
+                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
+                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
+                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
+                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
+                    uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
+                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
+                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
+
+                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);
+                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);
+                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);
+                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);
+
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 23
+                    qs_0 = vld1q_u8(qs_base + 16);
+                    qs_1 = vld1q_u8(qs_base + 80);
+                    qs_2 = vld1q_u8(qs_base + 144);
+                    qs_3 = vld1q_u8(qs_base + 208);
+
+                    hbit_lo_0 = vandq_u8(qh[1][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[1][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[1][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[1][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
+
+                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);
+                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);
+                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);
+                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);
+
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 45
+                    qs_0 = vld1q_u8(qs_base + 32);
+                    qs_1 = vld1q_u8(qs_base + 96);
+                    qs_2 = vld1q_u8(qs_base + 160);
+                    qs_3 = vld1q_u8(qs_base + 224);
+
+                    hbit_lo_0 = vandq_u8(qh[2][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[2][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[2][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[2][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
+
+                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);
+                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);
+                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);
+                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);
+
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 67
+                    qs_0 = vld1q_u8(qs_base + 48);
+                    qs_1 = vld1q_u8(qs_base + 112);
+                    qs_2 = vld1q_u8(qs_base + 176);
+                    qs_3 = vld1q_u8(qs_base + 240);
+
+                    hbit_lo_0 = vandq_u8(qh[3][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[3][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[3][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[3][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
+
+                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);
+                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);
+                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);
+                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);
+
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+                }
+
+                // Prepare bsum vectors for bias computation
+                // Each pair of subblocks shares the same bsums
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // Iterates over two column pairs (4 columns) at a time to use a single 128-bit register
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
+                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
+                    int16x4_t   group_mins_lo   = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
+                    int16x4_t   group_mins_hi   = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
+                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
+                    float32x4_t sb_min          = p == 0 ? sb_min_0 : sb_min_1;
+
+                    // 0123 or 4567
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+
+                    // FUSED BIAS: Compute and subtract bias immediately
+                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
+                    int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
+                    bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
+                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);
+                    acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
+                }
+            }  // for sb
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
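For reference, the per-super-block quantity the kernel above accumulates (and that the fused-bias comment describes) is the usual K-quant decomposition, with q the 5-bit quants, s_sb and m_sb the 6-bit sub-block scales and mins, d and dmin the per-column super-block scales, d_y the Q8_K row scale, and the second inner sum taken straight from the precomputed bsums:

$$\sum_i w_i\,y_i \;=\; d\,d_y \sum_{sb=1}^{8} s_{sb} \sum_{i \in sb} q_i\,y_i \;-\; d_{\min}\,d_y \sum_{sb=1}^{8} m_{sb} \sum_{i \in sb} y_i$$

In the code, the first term is the acc_lo/acc_hi dot products scaled by q5sb_scales and sb_scale, and the second is bias_f32 scaled by sb_min; the generic implementations accumulate the same two terms as sumf and sum_minf.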
 void ggml_gemv_q8_0_4x4_q8_0(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
@@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     for (int i = 0; i < 2; i++) {
                         int8_t    aux_q4sb[8];
                         const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                         q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                     }
 
@@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
                     int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
                     for (int i = 0; i < 2; i++) {
                         const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
                     }
 
                     // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
@@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    col_pairs     = ncols_interleaved / 2;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mone          = vdupq_n_u8(1);
+    const uint8x16_t mtwo          = vdupq_n_u8(2);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // bsum pairs belong to the same q8_k subblock
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                // Load qh once per block and shift after each subblock
+                const uint8_t * qh_base = q5_ptr[b].qh;
+                uint8x16_t      qh[col_pairs][4];
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                    qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                    qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                    qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q5sb_scales[2][8];
+                    int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, 1 subblock at a time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
+                         q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
+                         q8_qs_23[7] },
+                    };
+
+                    // Q5s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < col_pairs; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+
+                        // This is the only part of the algorithm that differs from Q4_K:
+                        // extract the high bits and pack them into 5-bit weights
+                        uint8x16_t hbit_lo_0    = vandq_u8(qh[cp][0], mone);
+                        uint8x16_t hbit_hi_0    = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
+                        qh[cp][0]               = vshrq_n_u8(qh[cp][0], 2);
+                        // Same as Q4_K: multiply the reconstructed weights with i8mm (vmmlaq_s32).
+                        const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
+                        int32x4_t       acc_0   = sb_acc[0];
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
+                        int32x4_t acc_2 = sb_acc[2];
+                        acc_2           = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
+                        const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
+                        int32x4_t       acc_1   = sb_acc[1];
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
+                        int32x4_t acc_3         = sb_acc[3];
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
+
+                        // Repeat for the remaining three 16-byte groups (quants 8..15, 16..23, 24..31)
+                        uint8x16_t hbit_hi_1    = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
+                        uint8x16_t hbit_lo_1    = vandq_u8(qh[cp][1], mone);
+                        qh[cp][1]               = vshrq_n_u8(qh[cp][1], 2);
+                        const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
+                        const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
+
+                        uint8x16_t hbit_hi_2    = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
+                        uint8x16_t hbit_lo_2    = vandq_u8(qh[cp][2], mone);
+                        qh[cp][2]               = vshrq_n_u8(qh[cp][2], 2);
+                        const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
+                        const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
+
+                        uint8x16_t hbit_lo_3    = vandq_u8(qh[cp][3], mone);
+                        uint8x16_t hbit_hi_3    = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
+                        qh[cp][3]               = vshrq_n_u8(qh[cp][3], 2);
+                        const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
+                        sb_acc[0]               = acc_0;
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
+                        sb_acc[2]               = acc_2;
+
+                        // Scales[i] corresponds to column i
+                        const int       scale_offset = cp * 2;
+                        const int32_t   s0           = q5sb_scales[0][scale_offset];
+                        const int32_t   s1           = q5sb_scales[0][scale_offset + 1];
+                        const int32x4_t block_scale  = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
+
+                        const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
+                        sb_acc[1]               = acc_1;
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
+                        sb_acc[3]               = acc_3;
+
+                        const int32_t   s2           = q5sb_scales[1][scale_offset];
+                        const int32_t   s3           = q5sb_scales[1][scale_offset + 1];
+                        const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
+                    }
+
+                    // Accumulate the bias term: bsums * mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks shares the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdup_n_s16).
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder the i8mm 2x2 output tiles to match the bias and output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q5_dmin, q8_d);
+
+                        float32x4_t       q5_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q5_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
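
A note on the i8mm path above, assuming standard SMMLA semantics for vmmlaq_s32: each int8x16_t operand is treated as a row-major 2x8 matrix and the 2x2 int32 product is accumulated into the destination lanes in row-major order. Schematically, for acc = vmmlaq_s32(acc, a, b):

// acc[0] += dot(a[ 0.. 7], b[ 0.. 7])
// acc[1] += dot(a[ 0.. 7], b[ 8..15])
// acc[2] += dot(a[ 8..15], b[ 0.. 7])
// acc[3] += dot(a[ 8..15], b[ 8..15])

With a holding an interleaved Q5 column pair and b an interleaved Q8 row pair, every lane is one (column, row) partial sum; the vzip/vcombine reorder near the end of the kernel regroups those 2x2 tiles into contiguous rows of the 4x8 output tile before the scales and bias are applied.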
 
 void ggml_gemm_q8_0_4x4_q8_0(int                        n,
                              float * GGML_RESTRICT      s,
index fbf7ed9432ab95b8647b98ca94075dec1914d1e5..19e021e59aab3cbfd80b488a0cca2590511f184a 100644 (file)
@@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     assert (n % qk == 0);
     assert (nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
     UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
     float sumf[8];
     float sum_minf[8];
@@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q5_K_8x8_q8_K_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int             qk                = QK_K;
+    const int             nb                = n / qk;
+    const int             ncols_interleaved = 8;
+    const int             blocklen          = 8;
+    static const uint32_t kmask1            = 0x3f3f3f3f;
+    static const uint32_t kmask2            = 0x0f0f0f0f;
+    static const uint32_t kmask3            = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float    sumf[8];
+    float    sum_minf[8];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j]     = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                utmp[sb * 4 + 3]      = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2]      = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
+                uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
+
+                const int qh_shift = (k / 4) * 2;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi  = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                        const int qh_idx      = (k * 8 + i) % 32;
+                        const int qh_chunk    = qh_idx / 8;
+                        const int qh_pos      = qh_idx % 8;
+                        const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
+
+                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                        const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                        const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                        const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
+
+                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
+                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q5_K_8x8_q8_K_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK_K;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen          = 8;
+
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float    sumf[4][8];
+    float    sum_minf[4][8];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j]     = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int sb = 0; sb < 8; sb++) {
+                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                    utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                    utmp[sb * 4 + 2]      = uaux_0;
+                    utmp[sb * 4 + 0] &= kmask1;
+                }
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
+                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
+
+                    const int qh_shift = (k / 4) * 2;
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi  = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                                const int qh_idx      = (k * 8 + i) % 32;
+                                const int qh_chunk    = qh_idx / 8;
+                                const int qh_pos      = qh_idx % 8;
+                                const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
+
+                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                                const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                                const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                                const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
+
+                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                    for (int m = 0; m < 4; m++) {
+                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
+                                              GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
 
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
@@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
         out.scales[i] = in[src1].scales[src2];
     }
     return out;
+}
+
+static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
+    block_q5_Kx8 out;
+    // Delta (scale) and dmin values of the eight Q5_K structures are copied into the output interleaved structure
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 8; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
 
+    const int end = QK_K * 4 / blck_size_interleave;
+
+    // Interleave Q5_K quants by taking 8 bytes at a time
+    for (int i = 0; i < end; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    // Repeat for the high bits (qh), 8 bytes at a time as well, since
+    // the high bits are interleaved in Q5_K and indexed as
+    // qh_idx = qs_idx % 32;
+    // qh_bit = (qh[qh_idx] >> (qs_idx / 32)) & 1;
+    for (int i = 0; i < end / 4; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
+        memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    // The logic below is copied over from Q4_K.
+    // The goal is to unpack all the scales and mins of one sub-block from every 12 bytes loaded.
+    // The Q5_K structure packs 8 scales and 8 mins into 12 bytes (6 bits per value).
+    // The output Q5_Kx8 structure holds 96 bytes of scales.
+    // Each 12-byte group packs the scales and mins of the corresponding sub-block from all eight Q5_K structures.
+    // For example, the first 12 bytes contain the 8 scales and 8 mins of the first sub-block of each Q5_K structure.
+    uint8_t s[8], m[8];
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = in[j].scales[i] & 63;
+            m[j] = in[j].scales[i + 4] & 63;
+        }
+
+        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
+            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
+        }
+
+        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    return out;
 }
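
As a sanity check on the packing above, the 96-byte scales array can be read back group by group: bytes [sub_block * 12, sub_block * 12 + 12) hold the 6-bit scales and mins of that sub-block for all eight interleaved rows. The scalar decode below is a sketch only (not part of this commit); it simply inverts the two packing loops above and corresponds to the utmp/kmask unpacking done in the generic kernels.

static void unpack_q5_Kx8_scale_group(const uint8_t * p /* = &scales[sub_block * 12] */,
                                      uint8_t sc[8], uint8_t mn[8]) {
    for (int r = 0; r < 4; ++r) {
        // rows 0..3: 6 bits stored directly in the low bits of the first 8 bytes
        sc[r] = p[r]     & 63;
        mn[r] = p[r + 4] & 63;
        // rows 4..7: low 4 bits in the last 4 bytes, top 2 bits in bits 6-7 of the first 8 bytes
        sc[r + 4] = (p[r + 8] & 0xF) | (((p[r]     >> 6) & 3) << 4);
        mn[r + 4] = (p[r + 8] >>  4) | (((p[r + 4] >> 6) & 3) << 4);
    }
}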
 
 static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
@@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
+    GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q5_Kx8 *     dst = (block_q5_Kx8 *) t->data;
+    const block_q5_K * src = (const block_q5_K *) data;
+    block_q5_K         dst_tmp[8];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
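
For instance, with t->ne[0] == 512 and t->ne[1] == 16 this gives nblocks == 2 and two row groups: dst[0] interleaves block 0 of rows 0..7 (src[0], src[2], ..., src[14]), dst[1] interleaves block 1 of the same rows (src[1], src[3], ..., src[15]), and dst[2]/dst[3] repeat the pattern for rows 8..15.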
+
 static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 8);
@@ -1936,6 +2245,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
     return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
 }
@@ -1973,6 +2286,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -1981,8 +2298,8 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -2013,20 +2330,24 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
+    // instance for Q5_K
+    static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
+
     // instance for Q2
     static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
 
@@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q2_K_8x8_q8_K;
             }
         }
+    } else if (cur->type == GGML_TYPE_Q5_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q5_K_8x8_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
index af98e70344287bccd5216470c8b64c2cd5e9f4e3..da87103157e7215e853611eb25fff3c3a12069ef 100644 (file)
@@ -44,6 +44,7 @@ struct block_q4_Kx8 {
 };
 
 static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
 struct block_q2_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
     ggml_half dmin[8];   // super-block scale for quantized mins
@@ -52,6 +53,18 @@ struct block_q2_Kx8 {
 };
 
 static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
+
+struct block_q5_Kx8 {
+    ggml_half d[8];              // super-block scale for quantized scales
+    ggml_half dmin[8];           // super-block scale for quantized mins
+    uint8_t   scales[96];        // scales and mins, quantized with 6 bits
+    uint8_t   qh[QK_K * 8 / 8];  // high bits of 5-bit quants
+    uint8_t   qs[QK_K * 8 / 2];  // low 4 bits of the 5-bit quants
+};
+
+static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
+              "wrong q5_K block size/padding");
+
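The size works out to 1408 bytes per repacked super-block: 16 ggml_half deltas (32 bytes) + 96 bytes of packed scales/mins (K_SCALE_SIZE * 8) + 256 bytes of high bits (QK_K) + 1024 bytes of low nibbles (QK_K * 4), i.e. the same total size as the eight block_q5_K structures it interleaves.
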
 struct block_q8_Kx4 {
     float d[4];              // delta
     int8_t qs[QK_K * 4];     // quants
@@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);