]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: more mul mat optimizations (llama/18533)
authorEve <redacted>
Wed, 7 Jan 2026 10:13:17 +0000 (10:13 +0000)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* q4_k

* q5_k

* q2_k

* q4_1

* q5_1

* better buf index

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 376944f1e2147bb7781030a0e2a5d4149f2afa7d..7865a6bda79e08d7ef1c2085e9d570f0093e0765 100644 (file)
@@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
 
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 vec2 get_dm(uint ib, uint a_offset) {
-    return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
+    const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
+    return dm;
 }
 #endif
 
index 1a3531761aa42983e3ecce62af3106f5cd7f9aa4..ce7f2d699a212b22b50c434e97d4c42df1c3cdea 100644 (file)
@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 #endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 4;
             const uint iqs = idx & 0x03;
@@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
 #elif defined(DATA_A_Q4_1)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 4;
             const uint iqs = idx & 0x03;
 
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
-            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
-            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+            const vec2 dm = vec2(data_a_packed32[ib].dm);
+            const uint vui = data_a_packed32[ib].qs[iqs];
+            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
+            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
 
             buf_a[buf_idx     ] = FLOAT_TYPE_VEC2(v0.xy);
             buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
@@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
 #elif defined(DATA_A_Q5_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = idx & 0x07;
@@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
 #elif defined(DATA_A_Q5_1)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
-
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint uint_qh = data_a_packed16[ib].qh;
-            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
-            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
-            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
-            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xz);
-            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
+            const vec2 dm = vec2(data_a_packed32[ib].dm);
+            const uint uint_qh = data_a_packed32[ib].qh;
+            const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
+            const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
+            const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
+            const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
+
+            const uint vui = data_a_packed32[ib].qs[iqs];
+            const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
+            const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
+
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v0.xz);
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
+            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
+            buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
 #elif defined(DATA_A_Q8_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                         // 2 values per idx
-            const uint iqs = idx % 128;                        // 0..127
+            const uint ib = idx / 64;                          // 4 values per idx
+            const uint iqs = (idx % 64) * 2;                   // 0,2,4..126
 
             const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
             const uint scalesi = iqs / 8;                      // 0..15
             const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
 
-            const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
+            const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
             const uint scales = data_a[ib].scales[scalesi];
             const vec2 dm = vec2(data_a[ib].dm);
 
-            const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
+            const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
 #elif defined(DATA_A_Q3_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
+            const uint ib = idx / 64;                  // 4 values per idx
+            const uint iqs = (idx % 64) * 2;           // 0,2,4..126
 
             const uint n = iqs / 32;                   // 0,1,2,3
             const uint b = (iqs % 32) / 16;            // 0,1
@@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
+            const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
-                                             fma(d, q.y, m));
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
+            const uint ib = idx / 64;                  // 4 values per idx
+            const uint iqs = (idx % 64) * 2;           // 0,2,4..126
 
             const uint n = iqs / 32;                   // 0,1,2,3
             const uint b = (iqs % 32) / 16;            // 0,1
@@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
-            const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
-            const vec2 q = vec2(unpack8(qs | qh).xy);
+            const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
+            const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
+            const vec4 q = vec4(unpack8(qs | qh));
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
-                                             fma(d, q.y, m));
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = idx & 0x07;
@@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
                                                      kvalues_iq4nl[vui >> 12]);
 #elif defined(DATA_A_MXFP4)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = (idx & 0x07) * 2;
index 5b61ff9ca26689ffd1d8be5d6e3edae77cc1caf4..bbdbf9dcaaa787708950a9126570feebcf021892 100644 (file)
@@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
-        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
+        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
+        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {