]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix Q4_K and Q5_K for QK_K = 64 on CUDA (#2359)
authorKawrakow <redacted>
Tue, 25 Jul 2023 10:48:04 +0000 (13:48 +0300)
committerGitHub <redacted>
Tue, 25 Jul 2023 10:48:04 +0000 (13:48 +0300)
* Fix Q4_K and Q5_K for QK_K = 64

* Very slightly better Q5_K bit fiddling

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu

index 87a1660613613c35215e802a790a161542dd7f70..d31fc79c10961de17704b317b7489f7a0c5e1c9e 100644 (file)
@@ -1564,12 +1564,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
-
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
 
+#ifndef GGML_QKK_64
+
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
     const float    d = bq4_K->d;
     const float dmin = bq4_K->dmin;
 
@@ -1614,6 +1616,43 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     }
 
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1625,6 +1664,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
 
+#ifndef GGML_QKK_64
+
     const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
     const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
     const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
@@ -1680,6 +1721,42 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     }
 
     return d*sumf_d - dmin*sumf_m;
+
+#else
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * iqs; // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A