]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : optimize dequant q6_K kernel (#11892)
authorAdrian Kretz <redacted>
Sat, 15 Feb 2025 18:39:20 +0000 (19:39 +0100)
committerGitHub <redacted>
Sat, 15 Feb 2025 18:39:20 +0000 (20:39 +0200)
ggml/src/ggml-metal/ggml-metal.metal

index da415184b173ca286980474111de937bc2be3d3a..83e7ac9f411ef3f0c74fcb0aade6d691641d3a77 100644 (file)
@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
 
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }