]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: matmul dequantization improvements (#12015)
authorEve <redacted>
Fri, 28 Feb 2025 07:20:08 +0000 (07:20 +0000)
committerGitHub <redacted>
Fri, 28 Feb 2025 07:20:08 +0000 (08:20 +0100)
* faster dequant for old quants

* dont use unpack for iq4_nl

* vec2 unpack for q8

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 10318e87660741d02ba750fb45f9a5de10c87414..8835c442ecfd8470558d91d578d146aca2ac10e5 100644 (file)
@@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
 }
 vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
-    uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
-    uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
-    return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
+    const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
+    return vec4(v0.x, v0.y, v1.x, v1.y);
 }
 #endif
 
index 4770469eddcab979b4cb16023dcfeddfe85d2676..4ccbe613af2ce3e8d515c0ecb3c1ffde94ad21cd 100644 (file)
@@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
     const uint iqs = idx;
 
     // Load 16b and select the byte for this element
-    int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
+    int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
     float16_t ret = float16_t(qs) * d;
     return ret;
 }
index 39657195cfc8d37fe36aefa5c960311de36716e9..a8fd93fdeadee861588812849e0eaed322a42868 100644 (file)
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
@@ -243,74 +250,100 @@ void main() {
 #endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
+
+            const float d = float(data_a_packed16[ib].d);
+            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+            const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
+            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
 #elif defined(DATA_A_Q4_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
+
+            const float d = float(data_a_packed16[ib].d);
+            const float m = float(data_a_packed16[ib].m);
+            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
+            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
 #elif defined(DATA_A_Q5_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+            const float d = float(data_a_packed16[ib].d);
+            const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
+            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+            const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
 
             buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q5_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint uint_qh = data_a[ib].qh;
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+            const float d = float(data_a_packed16[ib].d);
+            const float m = float(data_a_packed16[ib].m);
+            const uint uint_qh = data_a_packed16[ib].qh;
+            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
 
             buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q8_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 16;
-            const uint iqs = (idx & 0xF) * 2;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+            const float d = float(data_a_packed16[ib].d);
+            const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
+            const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
+            const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
 
             buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q2_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -623,17 +656,18 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
 
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
+            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
+            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
 #endif
         }
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
index dfa16cda516bde4236d54687aee2ea239a4e3f61..907067d7fa56e80cae051c1866a69c7e98270931 100644 (file)
@@ -139,7 +139,7 @@ struct block_q8_0
 struct block_q8_0_packed16
 {
     float16_t d;
-    uint16_t qs[32/2];
+    int16_t qs[32/2];
 };
 
 #if defined(DATA_A_Q8_0)
index c5e0bba82b26040c2119eacbbe4809daee70837b..4a81505565e89b1dff9a39465b79b00ea2bf6e3d 100644 (file)
@@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
 
     for (const auto& tname : type_names) {
+        std::string load_vec_quant = "2";
+        if ((tname == "q4_0") || (tname == "q4_1"))
+            load_vec_quant = "8";
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+            load_vec_quant = "4";
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {