]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: dequantize iq4_xs 4 at a time (#20657)
authorEve <redacted>
Thu, 19 Mar 2026 10:32:04 +0000 (10:32 +0000)
committerGitHub <redacted>
Thu, 19 Mar 2026 10:32:04 +0000 (11:32 +0100)
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index ce7f2d699a212b22b50c434e97d4c42df1c3cdea..3f494eb4d5ae074c184e084f724af5046298d1da 100644 (file)
@@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint ib32 = (idx % 128) / 16;         // 0..7
-            const uint iq = 16 * ib32 + 2 * (idx % 8);
+            const uint ib = idx / 64;            // 4 values per idx
+            const uint ib32 = (idx % 64) / 8;    // 0..7
+            const uint iq = 4 * ib32 + (idx % 4);
 
             const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
             const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
-            const uint qshift = (idx & 8) >> 1;
-            u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
+            const uint qshift = idx & 4;
+            u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F);
 
             const float d = float(data_a[ib].d);
-            const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
+            const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
 
             buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
index 4b00ba3debb8091eb0c83c67c41d63d4bf031d5b..abd2a9c36fa1fd9a5834f8077460825e4f7ba6fe 100644 (file)
@@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
         std::string load_vec_quant = "2";
         if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
+        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {