]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: further optimize q5_k mul_mat_vec (llama/10479)
authorJeff Bolz <redacted>
Wed, 27 Nov 2024 07:21:59 +0000 (01:21 -0600)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

index 22a6bfae4fb615ff04a0ac961a0f7bf695e95779..b455cbd31ec913501bb03733955a16f70c75f83a 100644 (file)
@@ -34,9 +34,6 @@ void main() {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    const uint8_t hm1 = uint8_t(1 << (2*v_im));
-    const uint8_t hm2 = uint8_t(hm1 << 4);
-
     FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
@@ -71,6 +68,18 @@ void main() {
         uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
         uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
 
+        uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
+
+        uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
+        uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
+        uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
+        uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
+
+        qs0_16_u32_lo4 += qs0_16_lo4_offset16;
+        qs0_16_u32_hi4 += qs0_16_hi4_offset16;
+        qs64_80_u32_lo4 += qs64_80_lo4_offset16;
+        qs64_80_u32_hi4 += qs64_80_hi4_offset16;
+
         uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
         uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
         uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
@@ -102,31 +111,26 @@ void main() {
         B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
         B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
 
-        uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
-        uint32_t qh1 = qh0 >> 8;
-        uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
-        uint32_t qh17 = qh16 >> 8;
-
         const FLOAT_TYPE sx =
-          fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
-             FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
+          fma(FLOAT_TYPE(by10.x), q4_0,
+          fma(FLOAT_TYPE(by10.y), q4_1,
+          fma(FLOAT_TYPE(by116.x), q4_2,
+             FLOAT_TYPE(by116.y) * q4_3)));
         const FLOAT_TYPE sy =
-          fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
-             FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
+          fma(FLOAT_TYPE(by132.x), q4_4,
+          fma(FLOAT_TYPE(by132.y), q4_5,
+          fma(FLOAT_TYPE(by148.x), q4_6,
+             FLOAT_TYPE(by148.y) * q4_7)));
         const FLOAT_TYPE sz =
-          fma(FLOAT_TYPE(by20.x), (q4_8  + (((qh0 & hm2) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by20.y), (q4_9  + (((qh1 & hm2) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
-             FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
+          fma(FLOAT_TYPE(by20.x), q4_8,
+          fma(FLOAT_TYPE(by20.y), q4_9,
+          fma(FLOAT_TYPE(by216.x), q4_10,
+             FLOAT_TYPE(by216.y) * q4_11)));
         const FLOAT_TYPE sw =
-          fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
-             FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
+          fma(FLOAT_TYPE(by232.x), q4_12,
+          fma(FLOAT_TYPE(by232.y), q4_13,
+          fma(FLOAT_TYPE(by248.x), q4_14,
+             FLOAT_TYPE(by248.y) * q4_15)));
         const FLOAT_TYPE smin =
           fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
           fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,