]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sync : llama.cpp (ggml/0)
authorGeorgi Gerganov <redacted>
Wed, 21 Feb 2024 14:19:39 +0000 (16:19 +0200)
committerGeorgi Gerganov <redacted>
Thu, 22 Feb 2024 13:12:36 +0000 (15:12 +0200)
ggml-ci

examples/common-ggml.cpp
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index 32c97e89061c3b848dabbab5b3ee2457b0d4ad44..f1c7f6ef1874e8fedb9e1472a70d86e2fb5bea45 100644 (file)
@@ -66,6 +66,7 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_MOSTLY_IQ2_XS:
         case GGML_FTYPE_MOSTLY_IQ3_XXS:
         case GGML_FTYPE_MOSTLY_IQ1_S:
+        case GGML_FTYPE_MOSTLY_IQ4_NL:
                 {
                     fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                     return false;
@@ -199,6 +200,7 @@ bool ggml_common_quantize_0(
                 case GGML_TYPE_IQ2_XS:
                 case GGML_TYPE_IQ3_XXS:
                 case GGML_TYPE_IQ1_S:
+                case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_COUNT:
                     {
                         fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
index 6caae56b0244b0789336a8e23d5ef88d7ddbf6a2..e7c211d7d6087075eba4aa80d7fb815db867de0e 100644 (file)
@@ -528,6 +528,15 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+#define QK4_NL 32
+#define QR4_NL 2
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+typedef struct {
+    half d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1987,6 +1996,26 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
 
 }
 
+static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int tid = threadIdx.x;
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+
+}
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
@@ -4732,6 +4761,56 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
 #endif
 }
 
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
+        int & val1, int & val2) {
+
+    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+    aux32 = q4 & 0x0f0f0f0f;
+    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val1 = v1 | (v2 << 16);
+    aux32 = (q4 >> 4) & 0x0f0f0f0f;
+    v1 = values[q8[0]] | (values[q8[1]] << 8);
+    v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val2 = v1 | (v2 << 16);
+}
+#endif
+
+static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;
+
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+        get_int_from_table_16(aux, values, v1, v2);
+        sumi1 = __dp4a(v1, q8[l+0], sumi1);
+        sumi2 = __dp4a(v2, q8[l+4], sumi2);
+    }
+
+#else
+    const uint8_t * q4 = bq->qs + 4*iqs;
+    const int8_t  * q8 = bq8_1->qs + 4*iqs;
+
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
+        sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >>  4];
+    }
+#endif
+    const float d = (float)bq->d * __low2float(bq8_1->ds);
+    return d * (sumi1 + sumi2);
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -6777,6 +6856,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6818,6 +6903,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -6855,6 +6942,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -8599,6 +8688,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -8623,6 +8713,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -8724,6 +8815,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
             mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
                 (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -11446,7 +11541,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
                 ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
+                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL) {
                     if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                         return false;
                     }
index 956e323a0d053cb75fd57c8dead79e72b8603317..0d4aa43093739724913d8cd8bd119ee657320c58 100644 (file)
@@ -62,6 +62,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -104,6 +106,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -120,6 +123,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -136,6 +140,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -448,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,          get_rows_iq3_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,            get_rows_iq1_s,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,           get_rows_iq4_nl,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
@@ -471,6 +477,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,        mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,          mul_mv_iq1_s_f32,       ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,         mul_mv_iq4_nl_f32,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
@@ -490,6 +497,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,     mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,       mul_mv_id_iq1_s_f32,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,      mul_mv_id_iq4_nl_f32,   ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
@@ -506,6 +514,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,        mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,          mul_mm_iq1_s_f32,       ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,         mul_mm_iq4_nl_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
@@ -522,6 +531,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,     mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,       mul_mm_id_iq1_s_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,      mul_mm_id_iq4_nl_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
@@ -1338,6 +1348,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+                                case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                             }
 
@@ -1478,6 +1489,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ4_NL:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1525,6 +1542,11 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
+                            else if (src0t == GGML_TYPE_IQ4_NL) {
+                                const int mem_size = 32*sizeof(float);
+                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            }
                             else if (src0t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1619,6 +1641,7 @@ static bool ggml_metal_graph_compute(
                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+                                case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
                                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                             }
 
@@ -1762,6 +1785,12 @@ static bool ggml_metal_graph_compute(
                                         nth1 = 16;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                                     } break;
+                                case GGML_TYPE_IQ4_NL:
+                                    {
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+                                    } break;
                                 default:
                                     {
                                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1825,6 +1854,11 @@ static bool ggml_metal_graph_compute(
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
+                            else if (src2t == GGML_TYPE_IQ4_NL) {
+                                const int mem_size = 32*sizeof(float);
+                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            }
                             else if (src2t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
@@ -1867,6 +1901,7 @@ static bool ggml_metal_graph_compute(
                             case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                             case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
                             case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+                            case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
                             case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                             default: GGML_ASSERT(false && "not implemented");
                         }
index f0d77d446755cefcddb74a9fa2f8912f62d859ee..c223a981c246a39286647847ef5a3eab81fef9f8 100644 (file)
@@ -2531,6 +2531,12 @@ typedef struct {
     uint8_t scales[QK_K/16];
 } block_iq1_s;
 
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    half    d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
 
 //====================================== dot products =========================
 
@@ -4384,7 +4390,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const uint i13 = im/ne12;
 
     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
     device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
     device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
 
@@ -4447,6 +4452,103 @@ void kernel_mul_mv_iq1_s_f32_impl(
     }
 }
 
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+void kernel_mul_mv_iq4_nl_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK4_NL;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0 or 1
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK4_NL + it * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_nl & xb = x[row*nb + ib];
+            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] | (q4[1] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[2] | (q4[3] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 16 * QK4_NL;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
         device const  void * src0,
@@ -4475,6 +4577,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
     kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }
 
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
 
 //============================= templates and their specializations =============================
 
@@ -4838,6 +4968,21 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
     }
 }
 
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
@@ -5381,6 +5526,7 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2, dequantize_iq4_nl>;
 
 //
 // matrix-matrix multiplication
@@ -5421,6 +5567,7 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2, dequantize_iq4_nl>;
 
 //
 // indirect matrix-matrix multiplication
@@ -5473,6 +5620,7 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2, dequantize_iq4_nl>;
 
 //
 // matrix-vector multiplication
@@ -6503,3 +6651,68 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
+kernel void kernel_mul_mv_id_iq4_nl_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup float    * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq4_nl_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
index 3319d2ccf1812eac82b2770b0bf3dc81a52f0ba6..6336538f0e99ec79010c590ec65bcf9a1f69fe7f 100644 (file)
@@ -3754,6 +3754,26 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     }
 }
 
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
+    assert(k % QK4_NL == 0);
+    const int nb = k / QK4_NL;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * qs = x[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            y[j+       0] = d * kvalues_iq4nl[qs[j] & 0xf];
+            y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >>  4];
+        }
+        y  += QK4_NL;
+        qs += QK4_NL/2;
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -9148,7 +9168,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-// TODO
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -9452,7 +9471,100 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const
     *s = sumf;
 
 #endif
+}
+
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * restrict x = vx;
+    const block_q8_0   * restrict y = vy;
+
+    const int nb = n / QK4_NL;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
 
+    float sumf = 0;
+
+    for (int ib = 0; ib < nb; ib += 2) {
+
+        q4bits.val[0] = vld1q_u8(x[ib+0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib+1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib+0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib+0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib+1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib+1].qs + 16);
+
+        q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b));
+        q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b));
+        q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int ib = 0; ib < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
+        const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+                _mm256_cvtepi32_ps(p_2), accum2);
+
+        y += 2;
+        x += 2;
+    }
+
+    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#else
+    float sumf = 0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+        int sumi1 = 0, sumi2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+#endif
 }
 
 // ================================ IQ2 quantization =============================================
@@ -10729,3 +10841,123 @@ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, in
     }
     return nrow * nblock * sizeof(block_iq1_s);
 }
+
+// ============================ 4-bit non-linear quants
+
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
+        ggml_fp16_t * dh, uint8_t * q4,
+        float * weight, uint8_t * L,
+        const int8_t * values,
+        const float * quant_weights) {
+
+    const int ntry = 7;
+
+    float sigma2 = 0;
+    for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/QK4_NL;
+
+    const int nb = QK4_NL/block_size;
+
+    memset(q4, 0, QK4_NL/2);
+    for (int ib = 0; ib < nb; ++ib) {
+        dh[ib] = GGML_FP32_TO_FP16(0.f);
+        const float * xb = x + ib*block_size;
+        if (quant_weights) {
+            const float * qw = quant_weights + ib*block_size;
+            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        } else {
+            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+        }
+        float amax = 0, max = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float ax = fabsf(xb[j]);
+            if (ax > amax) {
+                amax = ax; max = xb[j];
+            }
+        }
+        if (!amax) {
+            continue;
+        }
+        float d = -max/values[0];
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float al = id*xb[j];
+            int l = best_index_int8(16, values, al);
+            float q = values[l];
+            float w = weight[j];
+            sumqx += w*q*xb[j];
+            sumq2 += w*q*q;
+        }
+        float best_id = id;
+        d = sumqx/sumq2;
+        float best = d*sumqx;
+        for (int itry = -ntry; itry <= ntry; ++itry) {
+            id = (itry + values[0])/max;
+            sumqx = sumq2 = 0;
+            for (int j = 0; j < block_size; ++j) {
+                float al = id*xb[j];
+                int l = best_index_int8(16, values, al);
+                float q = values[l];
+                float w = weight[j];
+                sumqx += w*q*xb[j];
+                sumq2 += w*q*q;
+            }
+            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                d = sumqx/sumq2; best = d * sumqx;
+                best_id = id;
+            }
+        }
+        dh[ib] = GGML_FP32_TO_FP16(d);
+        for (int j = 0; j < block_size; ++j) {
+            L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
+        }
+    }
+    for (int i = 0; i < QK4_NL/32; ++i) {
+        for (int j = 0; j < 16; ++j) {
+            q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+        }
+    }
+}
+
+size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK4_NL == 0);
+    int nblock = n_per_row/QK4_NL;
+    char * qrow = (char *)dst;
+    uint8_t L[QK4_NL];
+    float weight[32];
+    for (int row = 0; row < nrow; ++row) {
+        block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+            quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_nl);
+    }
+    return nrow * nblock * sizeof(block_iq4_nl);
+}
+
+void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_NL == 0);
+    block_iq4_nl * restrict y = vy;
+    quantize_row_iq4_nl_reference(x, y, k);
+}
+
+void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
+    assert(k % QK4_NL == 0);
+    quantize_iq4_nl(x, y, 1, k, NULL, NULL);
+}
+
index ad381cfab022d21c48bc8f524cb78d6d565d8a57..113623b62938a48e9170944894f6279d933c2f2f 100644 (file)
@@ -198,6 +198,14 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -217,6 +225,7 @@ void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGM
 void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k);
 void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
+void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int k);
 
 void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
@@ -232,6 +241,7 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@@ -251,6 +261,7 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -268,6 +279,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -276,6 +288,7 @@ size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq1_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
diff --git a/ggml.c b/ggml.c
index b90c7df2ca5ea4bee546a57f3405d7cbe866d6a8..5b9fa741a64799b151ad5a37bc960525036aecaf 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -690,6 +690,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ4_NL] = {
+        .type_name                = "iq4_nl",
+        .blck_size                = QK4_NL,
+        .type_size                = sizeof(block_iq4_nl),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_nl,
+        .from_float               = quantize_row_iq4_nl,
+        .from_float_reference     = (ggml_from_float_t)quantize_row_iq4_nl_reference,
+        .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name                = "q8_K",
         .blck_size                = QK_K,
@@ -2291,6 +2303,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
         case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
         case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
+        case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7724,6 +7737,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_add_q_f32(params, dst);
             } break;
@@ -8002,6 +8016,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_add1_q_f32(params, dst);
             } break;
@@ -8125,6 +8140,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         default:
             {
                 GGML_ASSERT(false);
@@ -11022,6 +11038,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_out_prod_q_f32(params, dst);
             } break;
@@ -11209,6 +11226,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         default:
             {
                 GGML_ASSERT(false);
@@ -11410,6 +11428,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_get_rows_q(params, dst);
             } break;
@@ -12109,6 +12128,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -12191,6 +12211,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -19725,6 +19746,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ4_NL:
+            {
+                GGML_ASSERT(start % QK4_NL == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
diff --git a/ggml.h b/ggml.h
index 004d09c700f3264d78bcbd9f2613398a21d16ba7..bed7a36a0ee6a3e61da0ce160764b9d568f0c178 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -355,6 +355,7 @@ extern "C" {
         GGML_TYPE_IQ2_XS  = 17,
         GGML_TYPE_IQ3_XXS = 18,
         GGML_TYPE_IQ1_S   = 19,
+        GGML_TYPE_IQ4_NL  = 20,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -393,6 +394,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
     };
 
     // available tensor operations: