]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add q3_s and q1_s (llama/5886)
authorAbhilash Majumder <redacted>
Mon, 11 Mar 2024 04:57:56 +0000 (10:27 +0530)
committerGeorgi Gerganov <redacted>
Thu, 14 Mar 2024 16:46:58 +0000 (18:46 +0200)
* Add q3_s and q1_s

* fix compilation

* fix build

* fix build

* fix build

* enable ops

* rm macro

* increase grid space

src/ggml-sycl.cpp

index 6fafa5282c8f99638fd3b0cb319ab01865ca7e19..6d56845821f694d82ed70835d8a6d33d1bcf3651 100644 (file)
@@ -3494,6 +3494,31 @@ typedef struct dpct_type_block_iq3_xxs {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+#define QR3_XS 8
+#define QI3_XS (QK_K / (4*QR3_XS))
+#if QK_K == 64
+#define IQ3S_N_SCALE 2
+#else
+#define IQ3S_N_SCALE QK_K/64
+#endif
+typedef struct {
+    sycl::half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+#define QR1_S 8
+#define QI1_S (QK_K / (4*QR1_S))
+typedef struct {
+    sycl::half d;
+    uint8_t qs[QK_K/8];
+    uint8_t scales[QK_K/16];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -4833,6 +4858,62 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
 
 }
 
+template<typename dst_t>
+static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint32_t *iq3s_grid,
+                                     const uint8_t *ksigns_iq2xs,
+                                     const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq3_s * x = (const block_iq3_s  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * qs = x[i].qs + 8*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3s_grid + qs[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3s_grid + qs[2*il+1]);
+    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+    const uint8_t signs = x[i].signs[4*ib + il];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint64_t *iq1s_grid,
+                                     const uint8_t *ksigns_iq2xs,
+                                     const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq1_s * x = (const block_iq1_s  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const int i8 = 4*ib+il;
+    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
+    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
+    const float d = (float)x[i].d * (2*(h & 7) + 1);
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
+#else
+    assert(false);
+#endif
+
+}
+
 /*
 DPCT1110:4: The total declared local variable size in device function
 dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7679,6 +7760,76 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
 #endif
 }
 
+static __dpct_inline__ float
+vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                     const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >=                                                 \
+    MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * qs = bq2->qs + 8*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4*ib32+l] >>  4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid1[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid2[0] ^ signs1, signs1, std::minus<>());
+        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                     const uint64_t *iq1s_grid, const uint64_t *ksigns64) {
+#if QK_K == 256
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+    const int ib32 = iqs;
+    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+    const uint8_t h1 = bq1->scales[2*ib32+0];
+    const uint8_t h2 = bq1->scales[2*ib32+1];
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
+    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
+    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
+    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
+    for (int j = 0; j < 2; ++j) {
+        sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1);
+        sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2);
+        sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3);
+        sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4);
+    }
+    const float d = (float)bq1->d  * bq8_1[ib32].ds[0];
+    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
+                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
 
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
           int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
@@ -8444,6 +8595,98 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
     }
 }
 
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+                          const sycl::nd_item<3> &item_ct1,
+                          const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+                          const sycl::nd_item<3> &item_ct1,
+                          const uint64_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
                                    const sycl::nd_item<3> &item_ct1) {
@@ -10129,6 +10372,64 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        iq3s_grid.init(*stream);
+        ksigns_iq2xs.init(*stream);
+        kmask_iq2xs.init(*stream);
+
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
+            auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
+            auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
+
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq3_s(
+                                     vx, y, item_ct1, iq3s_grid_ptr_ct1,
+                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        iq1s_grid.init(*stream);
+        ksigns_iq2xs.init(*stream);
+        kmask_iq2xs.init(*stream);
+
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
+            auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
+            auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
+
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_s(
+                                     vx, y, item_ct1, iq1s_grid_ptr_ct1,
+                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                             });
+        });
+    }
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_sycl(const void *__restrict__ vx,
                                dst_t *__restrict__ y, const int k,
@@ -10179,6 +10480,10 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
             return dequantize_row_iq2_xs_sycl;
         case GGML_TYPE_IQ3_XXS:
             return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl<float>;
         default:
@@ -10219,6 +10524,10 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
             return dequantize_row_iq2_xs_sycl;
         case GGML_TYPE_IQ3_XXS:
             return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
         case GGML_TYPE_F16:
             return convert_unary_sycl<sycl::half>;
         default:
@@ -10808,6 +11117,61 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        iq3s_grid.init(*stream);
+        ksigns64.init(*stream);
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
+            auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1,
+                            iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        iq1s_grid.init(*stream);
+        ksigns64.init(*stream);
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
+            auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1,
+                            iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                    });
+        });
+    }
+}
 
 static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols_x,
@@ -13556,8 +13920,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ3_XXS:
             return max_compute_capability >= VER_GEN9 ? 128 : 64;
+        case GGML_TYPE_IQ3_S:
+            return max_compute_capability >= VER_GEN9 ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
@@ -13618,6 +13985,12 @@ inline void ggml_sycl_op_mul_mat_vec_q(
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -16963,9 +17336,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
                 ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
+                if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
+                    a_type == GGML_TYPE_IQ4_XS) {
                     return false;
                 }
                 return true;