]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: optimize iq1 coopmat2 dequant functions (llama/12427)
authorJeff Bolz <redacted>
Wed, 19 Mar 2025 18:56:23 +0000 (13:56 -0500)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/types.comp

index 8efe4653ffe75bf15e480a5e903b04dc041dccd0..b3fad35e21d4e34b3d99885195a0c085e17dae96 100644 (file)
@@ -311,8 +311,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
     const float16_t d = bl.block.d;
     const uint idx = coordInBlock[1];
 
-    const uint ib32 = idx / 32;
-    const uint ib8 = idx / 8;
+    const uint ib32 = (idx & 0xE0) >> 5;
+    const uint ib8 = (idx & 0xF8) >> 3;
 
     const uint qh = bl.block.qh[ib32];
     const uint qs = bl.block.qs[ib8];
@@ -330,14 +330,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
    block_iq1_m block;
 };
 
+layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
+   block_iq1_m_packed64 block;
+};
+
 float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
-    const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
-    const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
+    decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
     const uint idx = coordInBlock[1];
 
-    const uint ib8 = idx / 8;
-    const uint ib16 = idx / 16;
+    uvec2 scales = unpack32(bl64.block.scales);
+    const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
+
+    const uint ib8 = (idx & 0xF8) >> 3;
+    const uint ib16 = (idx & 0xF0) >> 4;
     const int i8 = int(idx % 8);
     const uint sc = bl.block.scales[ib8 / 8];
     const uint qs = bl.block.qs[ib8];
index f01179326e7fc24955f16c1172adf7e19d594cf1..789776816b75a60e03d034f4a2ade45623dcf05a 100644 (file)
@@ -2,6 +2,7 @@
 #if !defined(GGML_TYPES_COMP)
 #define GGML_TYPES_COMP
 
+#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
@@ -312,6 +313,12 @@ struct block_iq1_m {
     uint16_t scales[QUANT_K_IQ1_M/64];
 };
 
+struct block_iq1_m_packed64 {
+    uint64_t  qs[QUANT_K_IQ1_M/8/8];
+    uint64_t  qh[QUANT_K_IQ1_M/16/8];
+    uint64_t scales;
+};
+
 #if defined(DATA_A_IQ1_S)
 #define QUANT_K QUANT_K_IQ1_S
 #define QUANT_R QUANT_R_IQ1_S