} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+ half d;
+ uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#endif
}
+static const __device__ uint64_t kgrid_iq2xxs[256] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ return true;
+ default:
+ return false;
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int i = blockIdx.x;
+ const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+ const int tid = threadIdx.x;
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+ assert(false);
+#endif
+
+}
+
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+ const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+ const int ib32 = iqs;
+ const uint16_t * q2 = bq2->qs + 4*ib32;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ uint32_t aux32 = q2[2] | (q2[3] << 16);
+ int sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ aux32 >>= 7;
+ }
+ const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+ return d * sumi;
+#else
+ // iqs is 0...15
+ const int ib32 = iqs/2;
+ const int il = iqs%2;
+ const uint16_t * q2 = bq2->qs + 4*ib32;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+ const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+ const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+ const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+ const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 8; ++j) {
+ sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+ sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+ }
+ return d * (sumi1 + sumi2);
+#endif
+#else
+ assert(false);
+ return 0.f;
+#endif
+}
+
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
#endif
}
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_IQ2_XXS:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+ break;
default:
GGML_ASSERT(false);
break;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(get_rows_i32);
+ GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(get_rows_i32);
+ GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
}
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(get_rows_i32);
+ GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
}
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
} break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+ //src0t == GGML_TYPE_IQ2_XXS ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src0t == GGML_TYPE_IQ2_XXS) {
+ [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
} break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+ } break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ //src2t == GGML_TYPE_IQ2_XXS ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src2t == GGML_TYPE_IQ2_XXS) {
+ [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
+ case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
default: GGML_ASSERT(false && "not implemented");
}
} block_q6_K;
// 210 bytes / block
+typedef struct {
+ half d;
+ uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+// 66 bytes / block for QK_K = 256, so 2.0625 bpw
+
//====================================== dot products =========================
void kernel_mul_mv_q2_K_f32_impl(
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
}
+// ======================= "True" 2-bit
+
+constexpr constant static uint64_t kgrid_iq2xxs[256] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+constexpr constant static uint8_t ksigns_iq2xs[128] = {
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+void kernel_mul_mv_iq2_xxs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+#if QK_K == 256
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xxs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * sum;
+
+ dh += nb*sizeof(block_iq2_xxs)/2;
+ q2 += nb*sizeof(block_iq2_xxs)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+#else
+ // TODO
+#endif
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
}
}
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// matrix-matrix multiplication
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// indirect matrix-matrix multiplication
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// matrix-vector multiplication
tiisg,
sgitg);
}
+
+[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_id_iq2_xxs_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_iq2_xxs_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_values,
+ tgpig,
+ tiisg,
+ sgitg);
+}
return (n/QK_K*sizeof(block_q6_K));
}
+// ====================== "True" 2-bit (de)-quantization
+
+void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
+ (void)x;
+ (void)y;
+ (void)k;
+ assert(k % QK_K == 0);
+ //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
+}
+
+static const uint64_t iq2xxs_grid[256] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const uint8_t ksigns_iq2xs[128] = {
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+ const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ }
+ }
+}
+
+void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK_K == 0);
+ block_iq2_xxs * restrict y = vy;
+ quantize_row_iq2_xxs_reference(x, y, k);
+}
+
+size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
+ assert(k % QK_K == 0);
+ (void)hist; // TODO: collect histograms
+
+ for (int j = 0; j < n; j += k) {
+ block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
+ quantize_row_iq2_xxs_reference(src + j, y, k);
+ }
+ return (n/QK_K*sizeof(block_iq2_xxs));
+}
+
//===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
x += QK_K;
continue;
}
- const float iscale = -128.f/max;
+ //const float iscale = -128.f/max;
+ // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+ const float iscale = -127.f/max;
for (int j = 0; j < QK_K; ++j) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
}
#endif
+
+static const int8_t keven_signs_q2xs[1024] = {
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+
+void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_iq2_xxs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ int8x16x4_t q2u;
+ int8x16x4_t q2s;
+ int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ float sumf1 = 0, sumf2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = vld1q_s8_x4(q8); q8 += 64;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+ }
+ sumf += d*(sumf1 + sumf2);
+ }
+ *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[1] >> 28;
+ const uint16_t ls2 = aux32[3] >> 28;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
+ q2 += 4;
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+#endif
+}
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml dsign, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+ ggml_fp16_t d;
+ uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
// Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
+ [GGML_TYPE_IQ2_XXS] = {
+ .type_name = "iq2_xxs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_xxs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
+ .from_float = quantize_row_iq2_xxs,
+ .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
+ .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
.blck_size = QK_K,
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+ case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
}
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
} break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ GGML_ASSERT(start % QK_K == 0);
+ block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
+ result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
+ } break;
case GGML_TYPE_F16:
{
int elemsize = sizeof(ggml_fp16_t);
GGML_TYPE_Q5_K = 13,
GGML_TYPE_Q6_K = 14,
GGML_TYPE_Q8_K = 15,
+ GGML_TYPE_IQ2_XXS = 16,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
};
// available tensor operations:
GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);