]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: small dequantization improvements (llama/18380)
authorEve <redacted>
Fri, 26 Dec 2025 17:12:11 +0000 (17:12 +0000)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
* iq4_xs

* quants

src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
src/ggml-vulkan/vulkan-shaders/types.glsl

index 70ee542d96952ccd3029f2f9b07f60402944b646..376944f1e2147bb7781030a0e2a5d4149f2afa7d 100644 (file)
@@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
     const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
     const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
     const uint qshift = (iqs & 16) >> 2;
-    u8vec4 qs = u8vec4(
-        data_a[a_offset + ib].qs[iq + 0],
-        data_a[a_offset + ib].qs[iq + 1],
-        data_a[a_offset + ib].qs[iq + 2],
-        data_a[a_offset + ib].qs[iq + 3]
-    );
-    qs = (qs >> qshift) & uint8_t(0xF);
+    const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F);
 
     const float dl = float(int(sl | (sh << 4)) - 32);
     return dl * vec4(
index 58ede04400d86ae582e02e83e2ac8966655d0e9a..1a3531761aa42983e3ecce62af3106f5cd7f9aa4 100644 (file)
@@ -159,14 +159,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint is = iqs / 8;                     // 0..15
             const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3
             const uint qsshift = halfsplit * 2;          // 0,2,4,6
-            const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128
 
             const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
                                   | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
             const float dl = float(data_a[ib].d) * float(us - 32);
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi    ] & m) != 0) ? 0 : 4)),
-                                             dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
+            const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
+            const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
+
+            buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
+                                             dl * (qs.y - hm.y));
 #elif defined(DATA_A_Q4_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -198,8 +200,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m),
-                                             fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
+            const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
+
+            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
+                                             fma(d, q.y, m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -213,8 +217,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
             const uint qhi = (iqs % 16) * 2;           // 0,2,4..30
 
-            const uint8_t hm = uint8_t(1 << (iqs / 16));
-
             const vec2 loadd = vec2(data_a[ib].dm);
 
             const uint scidx0 = (is < 4) ? is : (is + 4);
@@ -234,8 +236,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m),
-                                             fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
+            const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
+            const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
+            const vec2 q = vec2(unpack8(qs | qh).xy);
+
+            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
+                                             fma(d, q.y, m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -394,11 +400,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 
             const float d = float(data_a[ib].d);
             const uint qs = data_a[ib].qs[iqs];
-            const uint signs = pack32(u8vec4(
-                data_a[ib].qs[is+0],
-                data_a[ib].qs[is+1],
-                data_a[ib].qs[is+2],
-                data_a[ib].qs[is+3]
+            const uint signs = pack32(u16vec2(
+                data_a_packed16[ib].qs[is/2],
+                data_a_packed16[ib].qs[is/2+1]
             ));
             const float db = d * 0.5 * (0.5 + (signs >> 28));
             const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
@@ -443,8 +447,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
             const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
             const uint qshift = (idx & 8) >> 1;
-            u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
-            qs = (qs >> qshift) & uint8_t(0xF);
+            u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
 
             const float d = float(data_a[ib].d);
             const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
index 02578c77c4f310dd00843a7689f4566349406486..402a2a8397ddf86b1c9e5f1627e7692285269fa2 100644 (file)
@@ -172,16 +172,12 @@ struct block_q8_0
     float16_t d;
     int8_t qs[32];
 };
+
 struct block_q8_0_packed16
 {
     float16_t d;
     int16_t qs[32/2];
 };
-struct block_q8_0_packed32
-{
-    float16_t d;
-    int32_t qs[32/4];
-};
 
 #if defined(DATA_A_Q8_0)
 #define QUANT_K QUANT_K_Q8_0
@@ -189,7 +185,6 @@ struct block_q8_0_packed32
 #define QUANT_AUXF 1
 #define A_TYPE block_q8_0
 #define A_TYPE_PACKED16 block_q8_0_packed16
-#define A_TYPE_PACKED32 block_q8_0_packed32
 #define DATA_A_QUANT_LEGACY
 #endif
 
@@ -201,11 +196,13 @@ struct block_q8_1
     f16vec2 ds;
     int8_t qs[32];
 };
+
 struct block_q8_1_packed16
 {
     f16vec2 ds;
     int16_t qs[16];
 };
+
 struct block_q8_1_packed32
 {
     f16vec2 ds;
@@ -218,6 +215,7 @@ struct block_q8_1_x4
     f16vec2 ds[4];
     int32_t qs[32];
 };
+
 struct block_q8_1_x4_packed128
 {
     f16vec2 ds[4];
@@ -1346,10 +1344,28 @@ struct block_iq4_xs
     uint8_t qs[QUANT_K_IQ4_XS/2];
 };
 
+struct block_iq4_xs_packed16
+{
+    float16_t d;
+    uint16_t scales_h;
+    uint16_t scales_l[QUANT_K_IQ4_XS/128];
+    uint16_t qs[QUANT_K_IQ4_XS/4];
+};
+
+struct block_iq4_xs_packed32
+{
+    float16_t d;
+    uint16_t scales_h;
+    uint32_t scales_l;
+    uint32_t qs[QUANT_K_IQ4_XS/8];
+};
+
 #if defined(DATA_A_IQ4_XS)
 #define QUANT_K QUANT_K_IQ4_XS
 #define QUANT_R QUANT_R_IQ4_XS
 #define A_TYPE block_iq4_xs
+#define A_TYPE_PACKED16 block_iq4_xs_packed16
+#define A_TYPE_PACKED32 block_iq4_xs_packed32
 #endif
 
 #define QUANT_K_IQ4_NL 32