]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Dequant improvements rebase (llama/8255)
authorAidanBeltonS <redacted>
Wed, 3 Jul 2024 01:55:34 +0000 (02:55 +0100)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
* Single load for half2

* Store scales in local mem

* Vec load quantized values

src/ggml-sycl/common.hpp
src/ggml-sycl/convert.cpp
src/ggml-sycl/dequantize.hpp

index dfd4a7c2c606bb811070718f3226445436a402b1..476d847ca575e325d1b0a0d41714bd3c3728f645 100644 (file)
@@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x,
     return x;
 }
 
+// Helper for vec loading aligned data
+template <typename Tp, int n>
+inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
+    return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
+}
+
 #endif // GGML_SYCL_COMMON_HPP
index ce9de2b42b7220b9630778c635caae5bb4da0a13..a15271b516fa8066742b15e0745f3ae1f7bdcf90 100644 (file)
@@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q4_K(vx, y, item_ct1);
+                                 dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
                              });
+        });
     }
 }
 
index b6080d83a33eb86fdd46b640090026c7b3574806..ed8ad098bcb2fdec53c2bf79a3fbaa64a1311ee4 100644 (file)
@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
 #if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
-        d = q[j] & 63; m = q[j + 4] & 63;
+        d = q[j] & 63;
+        m = q[j + 4] & 63;
     } else {
         d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
         m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
 
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
+                                  uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = item_ct1.get_group(2);
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
 
     dst_t * y = yy + i*QK_K + 64*il + n*ir;
 
-    const float dall = x[i].dm[0];
-    const float dmin = x[i].dm[1];
+    const sycl::half2 dm = x[i].dm;
+    const float dall = dm[0];
+    const float dmin = dm[1];
 
-    const uint8_t * q = x[i].qs + 32*il + n*ir;
+    if (tid < 12)
+        scales_local[tid] = x[i].scales[tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[i].scales, sc, m);
-    const float d1 = dall * sc; const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[i].scales, sc, m);
-    const float d2 = dall * sc; const float m2 = dmin * m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
     for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q[l] & 0xF) - m1;
-        y[l +32] = d2 * (q[l] >>  4) - m2;
+        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
+        y[l +32] = d2 * (q_vec[l] >>  4) - m2;
     }
 #else
     const int tid = item_ct1.get_local_id(2);