]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
IQ1_M: 1.75 bpw quantization (llama/6302)
authorKawrakow <redacted>
Tue, 26 Mar 2024 14:21:27 +0000 (15:21 +0100)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:20:00 +0000 (13:20 +0200)
* iq1_m: basics

* iq1_m: basics-2

* iq1_m: CUDA dequantize works

Very 1st shot I get PPL = 9.76 for LLaMA-v2-7B.

* iq1_m: separate shifts for each group of 8 in a block

We get
PPL(LLaMA-v2-7B ) = 9.2810
PPL(LLaMA-v2-13B) = 6.8105

Not bad, but slightly higher than
  sqrt(PPL(IQ1_S) * PPL(IQ2_XXS))
which is the expected outcome given that IQ1_M is
halfway between IQ1_S and IQ2_XXS in terms of bpw.
From this, we would expect
 PPL = 9.14 for LLaMA-v2-7B
 PPL = 6.63 for LLaMA-v2-13B

* iq1_m: go to 3-bit scales

There is slight increase in PPL, but the 0.0625 bpw reduction
in size is totally worth it.

We now have
PPL(LLaMA-v2-7B ) = 9.4469 at 1.96 bpw
PPL(LLaMA-v2-13B) = 6.8717 at 1.93 bpw
PPL(LLaMA-v2-70B) = 4.8568 at 1.85 bpw

* iq1_m: scalar dot product

* iq1_m: AVX2 dot product

* iq1_m: very slightly faster AVX2 dot product

* iq1_m: ARM_NEON dot product

Works, but very slow (10.5 t/s)

* iq1_m: Metal - dequantize works, dot product does not

* iq1_m: Metal now works

About the same performance as iq1_s.

* iq1_m: minor

* iq1_m: checking pure iq1_m quantization

It is pretty bad: PPL(LLaMA-v2-7B) = 34 if we quantize output.weight
with Q4_K.

* iiq1_m: slightly faster ARM_NEON dot product

10.5 t/s -> 11.65 t/s

* iq1_m: faster ARM_NEON dot product

11.65 t/s -> 14.9 t/s

* iq1_m: another minor ARM_NEON dot product improvement

14.9 -> 15.0 t/s

* iq1_m: small PPL improvement via super-block scale adjustment

After quantizing block scales redo the super-block scale fit.

PPL(LLaMA-v2-7B ) = 9.3346
PPL(LLaMA-v2-13B) = 6.8419
PPL(LLaMA-v2-70B) = 4.8294
PPL(Mistral-7B  ) = 8.1624

* iq1_m: adapt to CUDA refactoring

* iq1_m: remove unused variable

We have progressed to warnings being errors.

* iq1_m: add to backend-ops tests

* iq1_m: fix Windows ARM

* iq1_m: use common definition of iq1m_scale_t

* cuda: assert -> NO_DEVICE_CODE

* iq1_M: PR comments

---------

Co-authored-by: Iwan Kawrakow <redacted>
include/ggml/ggml.h
src/ggml-common.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-quants.c
src/ggml-quants.h
src/ggml.c
tests/test-backend-ops.cpp

index c670caa6a3140389b0b246b35f67bae19c11a016..425c9b6ab2d6df4924ade1895be332dacfe50ca0 100644 (file)
@@ -369,6 +369,7 @@ extern "C" {
         GGML_TYPE_I32     = 26,
         GGML_TYPE_I64     = 27,
         GGML_TYPE_F64     = 28,
+        GGML_TYPE_IQ1_M   = 29,
         GGML_TYPE_COUNT,
     };
 
@@ -408,6 +409,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
     };
 
     // available tensor operations:
index 0257c928cea52110215664d2d07432c20ffaeb6d..517c9bb43b380ca15c7d9959f9cabd37c48589f3 100644 (file)
@@ -377,6 +377,20 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+// 1.8125 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 4-bit block scales
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
 // Non-linear quants
 #define QK4_NL 32
 typedef struct {
@@ -1050,6 +1064,7 @@ GGML_TABLE_END()
 
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
 #if defined(GGML_COMMON_IMPL_C)
 GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
     0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
index 4f50c9f9fc632051ca55999ed852fc80c6c9a6de..48232b6e18d6ce98f6c57dabdafc44297df9ff7b 100644 (file)
@@ -615,6 +615,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -643,6 +644,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -2560,7 +2562,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
                     a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
+                    a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index 416b24532f0d0efeb293d648d45e5726c220ac45..cbe22aa3792b460d376a69f26fa8ecd8f1a4f007 100644 (file)
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -490,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,            get_rows_iq3_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,            get_rows_iq2_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,            get_rows_iq1_m,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,           get_rows_iq4_xs,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
@@ -517,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,          mul_mv_iq3_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,          mul_mv_iq2_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,          mul_mv_iq1_m_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,         mul_mv_iq4_xs_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
@@ -540,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,       mul_mv_id_iq3_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,       mul_mv_id_iq2_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,       mul_mv_id_iq1_m_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,      mul_mv_id_iq4_xs_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
@@ -560,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,          mul_mm_iq3_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,          mul_mm_iq2_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,          mul_mm_iq1_m_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,         mul_mm_iq4_xs_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
@@ -580,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,       mul_mm_id_iq3_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,       mul_mm_id_iq2_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,       mul_mm_id_iq1_m_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,      mul_mm_id_iq4_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
@@ -1421,6 +1431,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+                                case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1575,6 +1586,12 @@ static enum ggml_status ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ1_M:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+                                    } break;
                                 case GGML_TYPE_IQ4_NL:
                                     {
                                         nth0 = 4;
@@ -1619,9 +1636,9 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
 
-                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1  ||
-                                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0 ||
-                                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
+                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1743,6 +1760,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+                                case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1900,6 +1918,12 @@ static enum ggml_status ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ1_M:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+                                    } break;
                                 case GGML_TYPE_IQ4_NL:
                                     {
                                         nth0 = 4;
@@ -1960,9 +1984,9 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
                             }
 
-                            if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1  ||
-                                src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1  || src2t == GGML_TYPE_Q8_0 ||
-                                src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
+                            if (src2t == GGML_TYPE_Q4_0  || src2t == GGML_TYPE_Q4_1  || src2t == GGML_TYPE_Q5_0 ||
+                                src2t == GGML_TYPE_Q5_1  || src2t == GGML_TYPE_Q8_0  || src2t == GGML_TYPE_Q2_K ||
+                                src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -2024,6 +2048,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
                             case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
                             case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+                            case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M  ].pipeline; break;
                             case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
                             case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
index 748f0acef27b1c379d12547c6c4134f8e00aada3..e8083734ca4dfa069cd57c2e0e45e79a7c2c2c42 100644 (file)
@@ -4456,6 +4456,104 @@ void kernel_mul_mv_iq1_s_f32_impl(
     }
 }
 
+void kernel_mul_mv_iq1_m_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    iq1m_scale_t scale;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float4 sumy = {0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_m * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint8_t  * qh = xr->qh + 2 * ib;
+        device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+            float2 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+            sc += nb*sizeof(block_iq1_m)/2;
+            qs += nb*sizeof(block_iq1_m);
+            qh += nb*sizeof(block_iq1_m);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 void kernel_mul_mv_iq4_nl_f32_impl(
         device const  void * src0,
         device const float * src1,
@@ -4673,6 +4771,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
     kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }
 
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
 kernel void kernel_mul_mv_iq4_nl_f32(
         device const  void * src0,
@@ -5146,6 +5272,30 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
     }
 }
 
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    iq1m_scale_t scale;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
+    }
+}
+
 template <typename type4x4>
 void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
     device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5730,6 +5880,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_t kernel_get_rows<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  2,     dequantize_iq4_xs>;
@@ -5778,6 +5929,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_xs>;
@@ -5838,6 +5990,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  2,     dequantize_iq4_xs>;
@@ -7005,6 +7158,69 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         sgitg);
 }
 
+[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
+kernel void kernel_mul_mv_id_iq1_m_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq1_m_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
         device const    char * ids,
index f26798accdb9a5f4564a007d3873edb03b2bbed6..f717e616e6a2c3759308b8a97f3d6d7643b9c4e5 100644 (file)
@@ -3474,6 +3474,54 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     }
 }
 
+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    float delta[4];
+    uint16_t idx[4];
+
+    iq1m_scale_t scale;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+        const float d = GGML_FP16_TO_FP32(scale.f16);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
+            const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+            idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
+            idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
+            idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
+            idx[3] = qs[3] | ((qh[1] << 4) & 0x700);
+            delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            for (int l = 0; l < 2; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl1 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            for (int l = 2; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl2 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            qs += 4;
+            qh += 2;
+        }
+    }
+}
+
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
 void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
@@ -9695,6 +9743,206 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
 #endif
 }
 
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_m * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    iq1m_scale_t scale;
+
+#if defined __ARM_NEON
+
+    const int32x4_t mask  = vdupq_n_s32(0x7);
+    const int32x4_t mone  = vdupq_n_s32(1);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x4_t deltas;
+    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
+    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
+    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
+    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
+
+    ggml_int8x16x4_t q1b;
+    ggml_int8x16x4_t q8b;
+
+    uint32_t aux32;
+    const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int32x4_t sumi1 = mzero;
+        int32x4_t sumi2 = mzero;
+
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
+            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
+            const int32x4_t p12 = vpaddq_s32(p1, p2);
+
+            const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that
+            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
+
+            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
+            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
+            const int32x4_t p34 = vpaddq_s32(p3, p4);
+
+            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
+
+            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
+            sumi2 = vmlaq_s32(sumi2, scales_4, p34);
+
+            qs += 8; qh += 4;
+
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i mask = _mm256_set1_epi16(0x7);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
+            );
+            const __m256i q1b_2 = _mm256_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
+            );
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+
+            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
+            const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
+            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+            scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
+            scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
+            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
+            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
+            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
+            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
+        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
+
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#else
+
+    int sum1[2], sum2[2], delta[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            delta[0] = qh[0] & 0x08 ? -1 : 1;
+            delta[1] = qh[0] & 0x80 ? -1 : 1;
+            delta[2] = qh[1] & 0x08 ? -1 : 1;
+            delta[3] = qh[1] & 0x80 ? -1 : 1;
+            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
+                int lsum1 = 0, lsum2 = 0;
+                for (int j = 0; j < 8; ++j) {
+                    lsum1 += q8[j] * grid[j];
+                    lsum2 += q8[j];
+                }
+                q8 += 8;
+                sum1[l/2] += lsum1;
+                sum2[l/2] += lsum2*delta[l];
+            }
+            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
+            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
+            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
+            qs += 4;
+            qh += 2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+    }
+
+    *s = sumf;
+
+#endif
+}
+
 void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -9938,17 +10186,17 @@ static iq2_entry_t iq2_data[4] = {
 };
 
 static inline int iq2_data_index(enum ggml_type type) {
-    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
     return type == GGML_TYPE_IQ2_XXS ? 0 :
            type == GGML_TYPE_IQ2_XS  ? 1 :
-           type == GGML_TYPE_IQ1_S   ? 2 : 3;
+           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;
 }
 
 static inline int iq2_grid_size(enum ggml_type type) {
-    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
     return type == GGML_TYPE_IQ2_XXS ? 256 :
            type == GGML_TYPE_IQ2_XS  ? 512 :
-           type == GGML_TYPE_IQ1_S   ? NGRID_IQ1S : 1024;
+           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;
 }
 
 static int iq2_compare_func(const void * left, const void * right) {
@@ -10214,10 +10462,10 @@ void iq2xs_init_impl(enum ggml_type type) {
 
     const int kmap_size = 43692;
     //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
-    const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
+    const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
     const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
                              type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
-                             type == GGML_TYPE_IQ1_S   ? kgrid_1bit_2048 : kgrid_2bit_1024;
+                             type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024;
     uint64_t * kgrid_q2xs;
     int      * kmap_q2xs;
     uint16_t * kneighbors_q2xs;
@@ -10314,7 +10562,7 @@ void iq2xs_init_impl(enum ggml_type type) {
 }
 
 void iq2xs_free_impl(enum ggml_type type) {
-    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
     const int gindex = iq2_data_index(type);
     if (iq2_data[gindex].grid) {
         free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
@@ -11520,7 +11768,16 @@ static int iq1_sort_helper(const void * left, const void * right) {
 }
 
 #define IQ1S_BLOCK_SIZE 32
-static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+#define IQ1M_BLOCK_SIZE 16
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+        float    * scales,
+        float    * weight,
+        float    * sumx,
+        float    * sumw,
+        float    * pairs,
+        int8_t   * L,
+        uint16_t * index,
+        int8_t   * shifts) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
 
@@ -11534,22 +11791,17 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
     GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
+    block_iq1_s * y = vy;
+
     const int nbl = n/QK_K;
 
-    block_iq1_s * y = vy;
+    const int block_size = IQ1S_BLOCK_SIZE;
 
     const float x_p[3] = {-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA};
     const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};
 
-    float  scales[QK_K/IQ1S_BLOCK_SIZE];
-    float  weight[IQ1S_BLOCK_SIZE];
-    int8_t L[IQ1S_BLOCK_SIZE];
-    float  sumx[IQ1S_BLOCK_SIZE+1];
-    float  sumw[IQ1S_BLOCK_SIZE+1];
-    float  pairs[2*IQ1S_BLOCK_SIZE];
+
     int * idx = (int *)(pairs + 1);
-    uint16_t index[IQ1S_BLOCK_SIZE/8];
-    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
 
     for (int ibl = 0; ibl < nbl; ++ibl) {
 
@@ -11564,15 +11816,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
         for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
         float sigma2 = 2*sumx2/QK_K;
 
-        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
-            const float * xb = xbl + IQ1S_BLOCK_SIZE*ib;
-            const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib;
-            for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+            for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
             float max = fabsf(xb[0]);
-            for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i]));
+            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
             if (!max) {
                 scales[ib] = 0;
-                memset(L, 1, IQ1S_BLOCK_SIZE);
+                memset(L, 1, block_size);
                 continue;
             }
             // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
@@ -11581,14 +11833,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
             // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
             // for each possible and score for each split.
-            for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
+            for (int j = 0; j < block_size; ++j) {
                 pairs[2*j] = xb[j];
                 idx[2*j] = j;
             }
-            qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper);
+            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
             {
                 sumx[0] = sumw[0] = 0;
-                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
+                for (int j = 0; j < block_size; ++j) {
                     int i = idx[2*j];
                     sumx[j+1] = sumx[j] + weight[i]*xb[i];
                     sumw[j+1] = sumw[j] + weight[i];
@@ -11596,16 +11848,16 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             }
             float best_score = 0, scale = max;
             int besti1 = -1, besti2 = -1, best_shift = 0;
-            for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
-                for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
-                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_p[2];
-                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_p[2]*x_p[2];
+            for (int i1 = 0; i1 <= block_size; ++i1) {
+                for (int i2 = i1; i2 <= block_size; ++i2) {
+                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];
+                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];
                     if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                         scale = sumqx/sumq2; best_score = scale*sumqx;
                         besti1 = i1; besti2 = i2; best_shift = 1;
                     }
-                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_m[2];
-                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_m[2]*x_m[2];
+                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];
+                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];
                     if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                         scale = sumqx/sumq2; best_score = scale*sumqx;
                         besti1 = i1; besti2 = i2; best_shift = -1;
@@ -11615,14 +11867,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
             for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
-            for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
+            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
             if (scale < 0) {
-                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
+                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
                 scale = -scale; best_shift = -best_shift;
             }
             bool all_on_grid = true;
             const float * xx = best_shift == 1 ? x_p : x_m;
-            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+            for (int k = 0; k < block_size/8; ++k) {
                 uint16_t u = 0;
                 for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
                 int grid_index = kmap_q2xs[u];
@@ -11636,7 +11888,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             }
             if (!all_on_grid) {
                 float sumqx = 0, sumq2 = 0;
-                for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                for (int k = 0; k < block_size/8; ++k) {
                     const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
                     for (int j = 0; j < 8; ++j) {
                         float w = weight[8*k + j];
@@ -11648,8 +11900,8 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
                 if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
             }
             uint16_t h = 0;
-            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
-                y[ibl].qs[(IQ1S_BLOCK_SIZE/8)*ib + k] = index[k] & 255;
+            for (int k = 0; k < block_size/8; ++k) {
+                y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;
                 h |= (index[k] >> 8) << 3*k;
             }
             y[ibl].qh[ib] = h;
@@ -11660,14 +11912,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
         }
 
         if (!max_scale) {
-            memset(y[ibl].qs, 0, QK_K/8);
             continue;
         }
 
         float d = max_scale/15;
-        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed.
+        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
         float id = 1/d;
-        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
             l = MAX(0, MIN(7, l));
             if (shifts[ib] == -1) l |= 8;
@@ -11678,16 +11929,292 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
 
 size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
     GGML_ASSERT(n_per_row%QK_K == 0);
+    float  scales[QK_K/IQ1S_BLOCK_SIZE];
+    float  weight[IQ1S_BLOCK_SIZE];
+    int8_t L[IQ1S_BLOCK_SIZE];
+    float  sumx[IQ1S_BLOCK_SIZE+1];
+    float  sumw[IQ1S_BLOCK_SIZE+1];
+    float  pairs[2*IQ1S_BLOCK_SIZE];
+    uint16_t index[IQ1S_BLOCK_SIZE/8];
+    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
     int nblock = n_per_row/QK_K;
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
-        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
+        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
         src += n_per_row;
         qrow += nblock*sizeof(block_iq1_s);
     }
     return nrow * nblock * sizeof(block_iq1_s);
 }
 
+static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
+        float    * scales,
+        float    * weight,
+        float    * pairs,
+        int8_t   * L,
+        uint16_t * index,
+        int8_t   * shifts) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    block_iq1_m * y = vy;
+
+    const int nbl = n/QK_K;
+
+    const int block_size = IQ1M_BLOCK_SIZE;
+
+    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};
+    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
+    const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
+
+    int * idx = (int *)(pairs + 1);
+
+    float sumqx[4], sumq2[4];
+
+    iq1m_scale_t s;
+    const float * xx;
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        //y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(y[ibl].qs, 0, QK_K/8);
+        memset(y[ibl].qh, 0, QK_K/16);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            float max = fabsf(xb[0]);
+            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 1, block_size);
+                continue;
+            }
+            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
+            // for each possible and score for each split.
+            for (int j = 0; j < block_size; ++j) {
+                pairs[2*j] = xb[j];
+                idx[2*j] = j;
+            }
+            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
+            float best_score = 0, scale = max;
+            int besti1 = -1, besti2 = -1, best_k = -1;
+            // 0: +, +
+            // 1: +, -
+            // 2: -, +
+            // 3: -, -
+            for (int i1 = 0; i1 <= block_size; ++i1) {
+                for (int i2 = i1; i2 <= block_size; ++i2) {
+                    memset(sumqx, 0, 4*sizeof(float));
+                    memset(sumq2, 0, 4*sizeof(float));
+                    for (int j = 0; j < i1; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[0]*xb[i];
+                            sumqx[1] += weight[i]*x_p[0]*xb[i];
+                            sumqx[2] += weight[i]*x_m[0]*xb[i];
+                            sumqx[3] += weight[i]*x_m[0]*xb[i];
+                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[1] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[2] += weight[i]*x_m[0]*x_m[0];
+                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[0]*xb[i];
+                            sumqx[2] += weight[i]*x_p[0]*xb[i];
+                            sumqx[1] += weight[i]*x_m[0]*xb[i];
+                            sumqx[3] += weight[i]*x_m[0]*xb[i];
+                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[2] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[1] += weight[i]*x_m[0]*x_m[0];
+                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
+                        }
+                    }
+                    for (int j = i1; j < i2; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[1]*xb[i];
+                            sumqx[1] += weight[i]*x_p[1]*xb[i];
+                            sumqx[2] += weight[i]*x_m[1]*xb[i];
+                            sumqx[3] += weight[i]*x_m[1]*xb[i];
+                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[1] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[2] += weight[i]*x_m[1]*x_m[1];
+                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[1]*xb[i];
+                            sumqx[2] += weight[i]*x_p[1]*xb[i];
+                            sumqx[1] += weight[i]*x_m[1]*xb[i];
+                            sumqx[3] += weight[i]*x_m[1]*xb[i];
+                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[2] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[1] += weight[i]*x_m[1]*x_m[1];
+                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
+                        }
+                    }
+                    for (int j = i2; j < block_size; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[2]*xb[i];
+                            sumqx[1] += weight[i]*x_p[2]*xb[i];
+                            sumqx[2] += weight[i]*x_m[2]*xb[i];
+                            sumqx[3] += weight[i]*x_m[2]*xb[i];
+                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[1] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[2] += weight[i]*x_m[2]*x_m[2];
+                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[2]*xb[i];
+                            sumqx[2] += weight[i]*x_p[2]*xb[i];
+                            sumqx[1] += weight[i]*x_m[2]*xb[i];
+                            sumqx[3] += weight[i]*x_m[2]*xb[i];
+                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[2] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[1] += weight[i]*x_m[2]*x_m[2];
+                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
+                        }
+                    }
+                    for (int k = 0; k < 4; ++k) {
+                        if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
+                            scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
+                            besti1 = i1; besti2 = i2; best_k = k;
+                        }
+                    }
+                }
+            }
+            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
+            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
+            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
+            if (scale < 0) {
+                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
+                scale = -scale;
+                best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
+            }
+            bool all_on_grid = true;
+            for (int k = 0; k < block_size/8; ++k) {
+                if (k == 0) xx = best_k < 2 ? x_p : x_m;
+                else xx = best_k%2 == 0 ? x_p : x_m;
+                uint16_t u = 0;
+                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    all_on_grid = false;
+                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
+                    GGML_ASSERT(grid_index >= 0);
+                }
+                index[k] = grid_index;
+            }
+            if (!all_on_grid) {
+                float sumqx_f = 0, sumq2_f = 0;
+                for (int k = 0; k < block_size/8; ++k) {
+                    if (k == 0) xx = best_k < 2 ? x_p : x_m;
+                    else xx = best_k%2 == 0 ? x_p : x_m;
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+                    for (int j = 0; j < 8; ++j) {
+                        float w = weight[8*k + j];
+                        float q = xx[(pg[j] - 1)/2];
+                        sumqx_f += w*q*xb[8*k+j];
+                        sumq2_f += w*q*q;
+                    }
+                }
+                if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
+            }
+            y[ibl].qs[2*ib + 0] = index[0] & 255;
+            y[ibl].qs[2*ib + 1] = index[1] & 255;
+            y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            shifts[ib] = best_k;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        uint16_t * sc = (uint16_t *)y[ibl].scales;
+        float d = max_scale/15;
+        float id = 1/d;
+        float sumqx_f = 0, sumq2_f = 0;
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib+0]-1));
+            l = MAX(0, MIN(7, l));
+            sc[ib/4] |= (l << 3*(ib%4));
+            y[ibl].qh[ib] |= masks[shifts[ib]];
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int k = 0; k < block_size/8; ++k) {
+                if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
+                else xx = shifts[ib]%2 == 0 ? x_p : x_m;
+                const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
+                for (int j = 0; j < 8; ++j) {
+                    float w = weight[8*k + j];
+                    float q = xx[(pg[j] - 1)/2]*(2*l+1);
+                    sumqx_f += w*q*xb[8*k+j];
+                    sumq2_f += w*q*q;
+                }
+            }
+        }
+        if (sumq2_f > 0) d = sumqx_f/sumq2_f;
+        s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
+        sc[0] |= ((s.u16 & 0x000f) << 12);
+        sc[1] |= ((s.u16 & 0x00f0) <<  8);
+        sc[2] |= ((s.u16 & 0x0f00) <<  4);
+        sc[3] |= ((s.u16 & 0xf000) <<  0);
+    }
+}
+
+size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    float  scales[QK_K/IQ1M_BLOCK_SIZE];
+    float  weight[IQ1M_BLOCK_SIZE];
+    int8_t L[IQ1M_BLOCK_SIZE];
+    float  pairs[2*IQ1M_BLOCK_SIZE];
+    uint16_t index[IQ1M_BLOCK_SIZE/8];
+    int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq1_m);
+    }
+    return nrow * nblock * sizeof(block_iq1_m);
+}
+
 // ============================ 4-bit non-linear quants
 
 static inline int best_index_int8(int n, const int8_t * val, float x) {
index aa7e54a16e867e8f288b40ad1756c63162b5aa7e..ac1091c3d3b66b8ca083a25b17479870ec94a1b5 100644 (file)
@@ -72,6 +72,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@@ -94,6 +95,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -104,6 +106,7 @@ size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
 size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
 size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
 size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
+size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
 size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
 size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
 size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
index 62b8339599642bedb7b7ff737e2e6877d193259a..a86b41c158558820c5921e0ad2c827dea874016e 100644 (file)
@@ -794,6 +794,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ1_M] = {
+        .type_name                = "iq1_m",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq1_m),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_m,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_IQ4_NL] = {
         .type_name                = "iq4_nl",
         .blck_size                = QK4_NL,
@@ -2539,6 +2551,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
         case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
+        case GGML_FTYPE_MOSTLY_IQ1_M:         wtype = GGML_TYPE_IQ1_M;    break;
         case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
         case GGML_FTYPE_MOSTLY_IQ4_XS:        wtype = GGML_TYPE_IQ4_XS;   break;
         case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;
@@ -8135,6 +8148,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -8417,6 +8431,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -8544,6 +8559,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -11447,6 +11463,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -11638,6 +11655,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -11861,6 +11879,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -12564,6 +12583,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -12652,6 +12672,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -20306,7 +20327,8 @@ void ggml_quantize_init(enum ggml_type type) {
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ1_S:   iq2xs_init_impl(type); break;
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
         case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
         case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
         default: // nothing
@@ -20331,7 +20353,8 @@ bool ggml_quantize_requires_imatrix(enum ggml_type type) {
     return
         type == GGML_TYPE_IQ2_XXS ||
         type == GGML_TYPE_IQ2_XS  ||
-        type == GGML_TYPE_IQ1_S;
+        type == GGML_TYPE_IQ1_S;//   ||
+        //type == GGML_TYPE_IQ1_M;
 }
 
 size_t ggml_quantize_chunk(
@@ -20375,6 +20398,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
 #if QK_K == 64
         case GGML_TYPE_IQ4_XS:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
index 1998e1cbc4703b43e1620e52c3bf4e6c6df5c3ad..5dfea5662eb0b5a9267d7e5a2a3a2a05e1803423 100644 (file)
@@ -1960,7 +1960,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
         GGML_TYPE_Q6_K,
         GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
+        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
     };