]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
IQ4_XS: a 4.25 bpw quantization (llama/5747)
authorKawrakow <redacted>
Tue, 27 Feb 2024 14:34:24 +0000 (16:34 +0200)
committerGeorgi Gerganov <redacted>
Wed, 28 Feb 2024 11:00:29 +0000 (13:00 +0200)
* Try IQ4_NL with blocks of 64 - does not look good

* iq4_xs: go to super-blocks of 256 and 6-bit scales for blocks of 32

* iq4_xs: CUDA works - 133.2 t/s

* iq4_xs: AVX2 dot product

* iq4_xs: ARM_NEON dot product

* iq4_nl: Metal implementation

As usual, Metal / Apple Silicon don't like my quants.

* iq3_xs: minor fix

* iq4_xs: shrink by using IQ3_S for attn_k and attn_q

* iq4_xs: revert using IQ3_S for attn_k and attn_v

PPL vs size is good, but CPU performance suffers: on M2 Max
TG-128 drops to 21.7 t/s from 28.8, and on a Ryzen-7950X
to 14.5 t/s from 15.8 t/s. On CUDA we have 135 t/s when
using IQ3_S vs 133 t/s with pure IQ4_XS.

* Fix CI

* iq4_xs: Added forgotten check for 256 divisibility

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index 1242c0410df497efb7e822cd875a78aef86276a3..53b3ea2998099ef33ec7e11edf0dff52d2060a74 100644 (file)
@@ -571,6 +571,18 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+// QR4_XS = 8 is very slightly faster than QR4_XS = 4
+#define QR4_XS 8
+#define QI4_XS (QK_K / (4*QR4_XS))
+typedef struct {
+    half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -2427,6 +2439,25 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int tid = threadIdx.x;
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -5286,6 +5317,76 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
     return d * (sumi1 + sumi2);
 }
 
+static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#if QK_K == 256
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    //// iqs is 0...7
+    //const int ib64 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t  * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
+    //const int32_t  * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
+    //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
+    //const uint32_t * q4_2 = q4_1 + 4;
+    //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
+    //const int8_t ls2 = (bq4->scales_l[ib64] >>  4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
+    //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
+    //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4_1[j], values, v1, v2);
+    //    sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
+    //    get_int_from_table_16(q4_2[j], values, v1, v2);
+    //    sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
+    //}
+    //return d1 * sumi1 + d2 * sumi2;
+
+    // iqs is 0...7
+    const int ib32 = iqs;
+    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 4; ++j) {
+        get_int_from_table_16(q4[j], values, v1, v2);
+        sumi1 = __dp4a(v1, q8[j+0], sumi1);
+        sumi2 = __dp4a(v2, q8[j+4], sumi2);
+    }
+    return d * (sumi1 + sumi2);
+
+    //// iqs is 0...15
+    //const int ib32 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t  * q8 = (const int *)bq8_1[ib32].qs + 2*il;
+    //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
+    //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4[j], values, v1, v2);
+    //    sumi1 = __dp4a(v1, q8[j+0], sumi1);
+    //    sumi2 = __dp4a(v2, q8[j+4], sumi2);
+    //}
+    //return d * (sumi1 + sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -7340,6 +7441,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -7385,6 +7492,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
@@ -7428,6 +7537,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
@@ -9176,6 +9287,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
@@ -9203,6 +9315,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
@@ -9313,6 +9426,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
             mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
@@ -12041,7 +12158,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
                     a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ2_S) {
+                    a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index ffd24c86590def9e78317a137bf8caf8e1d62e15..71fcca5605914ee263996816c1084f040285e49f 100644 (file)
@@ -65,6 +65,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -113,6 +115,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -132,6 +135,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -151,6 +155,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -466,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,            get_rows_iq2_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,           get_rows_iq4_xs,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
@@ -492,6 +498,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,          mul_mv_iq2_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,         mul_mv_iq4_xs_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
@@ -514,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,       mul_mv_id_iq2_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,      mul_mv_id_iq4_xs_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
@@ -533,6 +541,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,          mul_mm_iq2_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,         mul_mm_iq4_xs_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
@@ -552,6 +561,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,       mul_mm_id_iq2_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,      mul_mm_id_iq4_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
@@ -1371,6 +1381,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+                                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                             }
 
@@ -1529,6 +1540,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ4_XS:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1576,7 +1593,7 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src0t == GGML_TYPE_IQ4_NL) {
+                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1678,6 +1695,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+                                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                             }
 
@@ -1839,6 +1857,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ4_XS:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1902,7 +1926,7 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ4_NL) {
+                            else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1952,6 +1976,7 @@ static bool ggml_metal_graph_compute(
                             case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
                             case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                             case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+                            case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                             default: GGML_ASSERT(false && "not implemented");
                         }
index 47354e952944018e7e9f56d4da3e8d7afbbac97a..6894119035b9c8b38d54710bb78aa1c5fe586c51 100644 (file)
@@ -2560,6 +2560,13 @@ typedef struct {
     uint8_t qs[QK4_NL/2];
 } block_iq4_nl;
 
+typedef struct {
+    half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -5160,6 +5167,100 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     }
 }
 
+void kernel_mul_mv_iq4_xs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/16;  // 0 or 1
+    const int it = tiisg%16;  // 0...15
+    const int ib = it/2;
+    const int il = it%2;
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ibl = ix; ibl < nb; ibl += 2) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_xs & xb = x[row*nb + ibl];
+            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] & 0x0f0f0f0f;
+            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[1] & 0x0f0f0f0f;
+            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 2 * QK_K;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
         device const  void * src0,
@@ -5217,6 +5318,35 @@ kernel void kernel_mul_mv_iq4_nl_f32(
     kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -5638,6 +5768,26 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
     }
 }
 
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
@@ -6183,7 +6333,8 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2, dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-matrix multiplication
@@ -6226,7 +6377,8 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // indirect matrix-matrix multiplication
@@ -6281,7 +6433,8 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-vector multiplication
@@ -7507,3 +7660,68 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
+kernel void kernel_mul_mv_id_iq4_xs_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup float    * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq4_xs_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
index 73c3bb4123da5ab9f247d6e56050873785e8dab7..607d50925b6da718c44350a2c326b99a894237d5 100644 (file)
@@ -4225,6 +4225,29 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
     }
 }
 
+void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * qs = x[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                y[j+16] = dl * kvalues_iq4nl[qs[j] >>  4];
+            }
+            y  += 32;
+            qs += 16;
+        }
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -9675,8 +9698,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
             qs += 8;
 
             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
-            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
-            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
             vs.val[0] = vceqq_u8(vs.val[0], mask2);
             vs.val[1] = vceqq_u8(vs.val[1], mask2);
 
@@ -9684,8 +9707,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
             q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
 
             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
-            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
-            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
             vs.val[0] = vceqq_u8(vs.val[0], mask2);
             vs.val[1] = vceqq_u8(vs.val[1], mask2);
 
@@ -10425,6 +10448,134 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_K == 0);
+
+    const block_iq4_xs * restrict x = vx;
+    const block_q8_K   * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
+
+    float sumf = 0;
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        const int8_t  * q8 = y[ibl].qs;
+        const uint8_t * q4 = x[ibl].qs;
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+
+            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
+            h >>= 4;
+            sumi1 += vaddvq_s32(prod_1) * ls1;
+            sumi2 += vaddvq_s32(prod_2) * ls2;
+
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+            const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+            sumi1 = _mm256_add_epi32(p_1, sumi1);
+            sumi2 = _mm256_add_epi32(p_2, sumi2);
+        }
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
+    float sumf = 0;
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+        uint16_t h = x[ibl].scales_h;
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
+            h >>= 4;
+            const float d1 = d4d8*(ls1 - 32);
+            const float d2 = d4d8*(ls2 - 32);
+            int sumi1 = 0, sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d1 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+            sumi1 = sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d2 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+        }
+    }
+    *s = sumf;
+#endif
+}
+
 // ================================ IQ2 quantization =============================================
 
 typedef struct {
@@ -12021,23 +12172,23 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
-static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
-        ggml_fp16_t * dh, uint8_t * q4,
-        float * weight, uint8_t * L,
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
+        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
+        float * scales, float * weight, uint8_t * L,
         const int8_t * values,
         const float * quant_weights) {
 
     const int ntry = 7;
 
     float sigma2 = 0;
-    for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
-    sigma2 *= 2.f/QK4_NL;
+    for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/super_block_size;
 
-    const int nb = QK4_NL/block_size;
+    memset(q4, 0, super_block_size/2);
+    dh[0] = GGML_FP32_TO_FP16(0.f);
 
-    memset(q4, 0, QK4_NL/2);
-    for (int ib = 0; ib < nb; ++ib) {
-        dh[ib] = GGML_FP32_TO_FP16(0.f);
+    float max_scale = 0, amax_scale = 0;
+    for (int ib = 0; ib < super_block_size/block_size; ++ib) {
         const float * xb = x + ib*block_size;
         if (quant_weights) {
             const float * qw = quant_weights + ib*block_size;
@@ -12053,6 +12204,7 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
             }
         }
         if (!amax) {
+            scales[ib] = 0;
             continue;
         }
         float d = -max/values[0];
@@ -12066,7 +12218,6 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
             sumqx += w*q*xb[j];
             sumq2 += w*q*q;
         }
-        float best_id = id;
         d = sumqx/sumq2;
         float best = d*sumqx;
         for (int itry = -ntry; itry <= ntry; ++itry) {
@@ -12082,15 +12233,47 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
             }
             if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                 d = sumqx/sumq2; best = d * sumqx;
-                best_id = id;
             }
         }
-        dh[ib] = GGML_FP32_TO_FP16(d);
-        for (int j = 0; j < block_size; ++j) {
-            L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
+        scales[ib] = d;
+        float abs_d = fabsf(d);
+        if (abs_d > amax_scale) {
+            amax_scale = abs_d; max_scale = d;
         }
     }
-    for (int i = 0; i < QK4_NL/32; ++i) {
+
+    if (super_block_size/block_size > 1) {
+        int nb = super_block_size/block_size;
+        memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
+        float d = -max_scale/32;
+        dh[0] = GGML_FP32_TO_FP16(d);
+        float id = d ? 1/d : 0.f;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            int l = nearest_int(id*scales[ib]);
+            l = MAX(-32, MIN(31, l));
+            float dl = d * l;
+            float idl = dl ? 1/dl : 0.f;
+            uint8_t * Lb = L + ib*block_size;
+            const float * xb = x + ib*block_size;
+            for (int j = 0; j < block_size; ++j) {
+                Lb[j] = best_index_int8(16, values, idl*xb[j]);
+            }
+            l += 32;
+            uint8_t l_l = l & 0xf;
+            uint8_t l_h = l >>  4;
+            if (ib%2 == 0) scales_l[ib/2] = l_l;
+            else scales_l[ib/2] |= (l_l << 4);
+            scales_h[ib/8] |= (l_h << 2*(ib%8));
+        }
+    } else {
+        dh[0] = GGML_FP32_TO_FP16(scales[0]);
+        float id = scales[0] ? 1/scales[0] : 0;
+        for (int j = 0; j < super_block_size; ++j) {
+            L[j] = best_index_int8(16, values, id*x[j]);
+        }
+    }
+
+    for (int i = 0; i < super_block_size/32; ++i) {
         for (int j = 0; j < 16; ++j) {
             q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
         }
@@ -12103,12 +12286,16 @@ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, i
     int nblock = n_per_row/QK4_NL;
     char * qrow = (char *)dst;
     uint8_t L[QK4_NL];
-    float weight[32];
+    float weight[QK4_NL];
+    uint16_t unused_h;
+    uint8_t * unused_l = NULL;
+    float scale;
     for (int row = 0; row < nrow; ++row) {
         block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
             const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
-            quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                    &scale, weight, L, kvalues_iq4nl, qw);
         }
         src += n_per_row;
         qrow += nblock*sizeof(block_iq4_nl);
@@ -12127,6 +12314,38 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
     quantize_iq4_nl(x, y, 1, k, NULL, NULL);
 }
 
+size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    uint8_t L[QK_K];
+    float weight[32];
+    float scales[QK_K/32];
+    for (int row = 0; row < nrow; ++row) {
+        block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
+                    scales, weight, L, kvalues_iq4nl, qw);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_xs);
+    }
+    return nrow * nblock * sizeof(block_iq4_xs);
+}
+
+void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq4_xs * restrict y = vy;
+    quantize_row_iq4_xs_reference(x, y, k);
+}
+
+void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) {
+    assert(k % QK_K == 0);
+    quantize_iq4_xs(x, y, 1, k, NULL, NULL);
+}
+
 // =============================== 2.5625 bpw
 
 static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
index 4731dde0cb5a960cfd200995c13c6ef3883c0634..2c61134c49e441b939129b8871ba466ac01ec9bd 100644 (file)
@@ -230,6 +230,14 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+typedef struct {
+    ggml_fp16_t d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -250,6 +258,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
 void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int k);
+void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs  * GGML_RESTRICT y, int k);
 void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int k);
 void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s   * GGML_RESTRICT y, int k);
 
@@ -268,6 +277,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq2_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
@@ -291,6 +301,7 @@ void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
@@ -311,6 +322,7 @@ void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
@@ -322,6 +334,7 @@ size_t quantize_iq2_s  (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq1_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq4_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq3_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
diff --git a/ggml.c b/ggml.c
index ab6d90838064eaf0f124f478ef6c2bd1817f6cf0..a23ca6417ef730e2512df5468ebd76fa69bf6037 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -730,6 +730,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ4_XS] = {
+        .type_name                = "iq4_xs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq4_xs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
+        .from_float               = quantize_row_iq4_xs,
+        .from_float_reference     = (ggml_from_float_t)quantize_row_iq4_xs_reference,
+        .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
         .blck_size                = QK_K,
@@ -2338,6 +2350,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
         case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
+        case GGML_FTYPE_MOSTLY_IQ4_XS:        wtype = GGML_TYPE_IQ4_XS;   break;
         case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;
         case GGML_FTYPE_MOSTLY_IQ2_S:         wtype = GGML_TYPE_IQ2_S;    break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
@@ -7776,6 +7789,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -8057,6 +8071,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -8183,6 +8198,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         default:
@@ -11083,6 +11099,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -11273,6 +11290,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         default:
@@ -11477,6 +11495,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -12179,6 +12198,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_Q8_K:
@@ -12264,6 +12284,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_Q8_K:
@@ -19835,6 +19856,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ4_XS:
+            {
+                GGML_ASSERT(start % QK4_NL == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
diff --git a/ggml.h b/ggml.h
index d21d09fc4c1bd11abc1d28781b25d3ec5d493b31..0a6d3c051fe72532d97c2b23c4f8ab749839b454 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -352,6 +352,7 @@ extern "C" {
         GGML_TYPE_IQ4_NL  = 20,
         GGML_TYPE_IQ3_S   = 21,
         GGML_TYPE_IQ2_S   = 22,
+        GGML_TYPE_IQ4_XS  = 23,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -393,6 +394,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
     };
 
     // available tensor operations: