]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: scale caching for k quants + misc fixes (llama/11081)
authorEve <redacted>
Wed, 15 Jan 2025 19:50:13 +0000 (19:50 +0000)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
* q6_k scale caching

* 16 bit unpack

* q4_k test (slow)

* revert it

* q3_k

* q2_k

* little stuff

* try precalculating products of a and q2_k scales

* Revert "try precalculating products of a and q2_k scales"

This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b.

* unpack should be u16, add vim swap to gitignore (about time)

* better q4_k scales

* q5_k

* better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations

* q2_k better dequant

* q3_k optimizations

* q3_k use hmask simd from cpu avx version

* make the caches happy

* q3_k separate out calculation

* q2_k separate out

* little stuff

* use calc_superblock everywhere

* q2_k optimize scale calculation

* more barriers

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

index 6a9b9b2d132a0febd37950ce6772d8d2cd6d3a57..8cdc640e80e31bb6badda522578814a88f8a898b 100644 (file)
@@ -5,6 +5,80 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+    const uint y_idx = i * QUANT_K + y_offset;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+
+        barrier();
+        if (!all_threads) { // when we don't have enough blocks to use all threads
+            if (i < num_blocks_per_row) {
+                const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
+                sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
+                sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+            }
+            barrier();
+
+            if (i >= num_blocks_per_row)
+                continue;
+        } else {
+            const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
+            sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
+            sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+            barrier();
+        }
+
+        const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
+        vec2 d = vec2(data_a[ib0 + i].d);
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
+            vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
+            vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+            vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+            vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+            vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+            vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ],
+                       fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
+                       fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ],
+                       fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
+                       fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ],
+                       fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
+                       fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ],
+                       fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im],
+                       fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im],
+                       fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im],
+                       fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im],
+                       fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im],
+                       fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im],
+                       fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im],
+                       fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
+            }
+            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
     const uint tid = gl_LocalInvocationID.x;
-    const uint itid = tid%16;  // 0...16
-    const uint ix  = tid/16;
-
-    const uint step = 8;
+    const uint itid = tid%16;  // 0...15
+    const uint ix = tid/16;
 
-    const uint v_im = itid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - step*v_im;                      // 0...15 or 0...7
+    const uint v_im = itid/8;                                // 0 or 1. 0 computes 0..., 1 computes 128...
+    const uint v_in = itid - 8*v_im;                         // 0...7
 
     const uint l0 = 2*v_in;                                  // 0...15
     const uint q_offset = 32*v_im + l0;
-    const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            vec2 d = vec2(data_a[ib0 + i].d);
-            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
-
-            uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
-            uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-            uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-            uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-            uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-            uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-            uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-            uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-            uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-            uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
-
-            uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
-            uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-            uvec2 qs0 =  uvec2(unpack8(qs0_u16));
-            uvec2 qs16 = uvec2(unpack8(qs16_u16));
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
-                vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
-                vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
-                vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
-                vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
-                vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
-                vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
-                vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
-
-                FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-                FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-                [[unroll]] for (int l = 0; l < 2; ++l) {
-                    sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
-                           fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                           fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
-                           fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                           fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
-                           fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                           fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
-                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-                    sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
-                           fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
-                           fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
-                           fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
-                           fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
-                           fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
-                           fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
-                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
-                }
-                temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
-            }
-        }
-    }
+    const uint nbr_par_th = num_blocks_per_row%it_size;
+    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+    uint i0 = 0;
+    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+        calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+    calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
index 96ef50fdda2a3cb9668d9a3a5c2a5f6b04e5c1e8..3116fad165be0b994bb6ecec67ce555b34123562 100644 (file)
@@ -5,6 +5,74 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+    const uint y_idx = i * QUANT_K + y_offset;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+
+        if (!all_threads) { // when we don't have enough blocks to use all threads
+            barrier();
+            if (i < num_blocks_per_row)
+                sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+            barrier();
+
+            if (i >= num_blocks_per_row)
+                continue;
+        }
+
+        const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
+        const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> (    v_im4)) << 2));
+        const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
+        const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
+        const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
+
+        // 0, 1, 16, 17
+        uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
+        qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
+        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
+        if (all_threads) {
+            barrier();
+            sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+            barrier();
+        }
+
+        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
+            vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
+            vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+            vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+            vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+            vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+            vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+            vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
+
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ],
+                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
+                      fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ],
+                      fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
+                      fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ],
+                      fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
+                      fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ],
+                      fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
+            }
+            temp[j][n] = fma(d, sum, temp[j][n]);
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
     const uint tid = gl_LocalInvocationID.x;
-    const uint itid = tid%16;  // 0...16
-    const uint ix  = tid/16;
-
-    const uint step = 8;
+    const uint itid = tid%16;  // 0...15
+    const uint ix = tid/16;
+    const uint itid8 = itid%8;
 
-    const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - step*v_im;                     // 0...15 or 0...7
+    const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128...
+    const uint v_im4 = v_im*4;
+    const uint v_in = itid - 8*v_im;                        // 0...7
 
-    const uint8_t m = uint8_t(1 << (4 * v_im));
+    const uint32_t m = 0x01010101 << (4 * v_im);
+    uint32_t hm_m[4];
+    [[unroll]] for (uint j = 0; j < 4; ++j)
+        hm_m[j] = m << j;
 
     const uint l0 = 2*v_in;                                 // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    const uint s_shift = 4 * v_im;
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-            uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
-            uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
-            uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
-            uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
-            uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
-            uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
-            u8vec2 s0 = unpack8(s0_16);
-            u8vec2 s2 = unpack8(s2_16);
-            u8vec2 s4 = unpack8(s4_16);
-            u8vec2 s6 = unpack8(s6_16);
-            u8vec2 s8 = unpack8(s8_16);
-            u8vec2 s10 = unpack8(s10_16);
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-
-                vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
-                vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
-                vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
-                vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
-                vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
-                vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
-                vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
-                vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
-
-                FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-                [[unroll]] for (int l = 0; l < 2; ++l) {
-                    sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                          fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
-                }
-                temp[j][n] = fma(d, sum, temp[j][n]);
-            }
-        }
-    }
+    const uint s_shift = v_im4 + 2*(itid8/4);
+
+    const uint nbr_par_th = num_blocks_per_row%it_size;
+    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+    uint i0 = 0;
+    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+        calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+    calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
index f97eb8744fb18b84d04aa26c95e59d4b16579422..f9cde064887a8c1a0680431f2e93b8d363ab950a 100644 (file)
@@ -6,6 +6,86 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+    const uint y1_idx = i * QUANT_K + y_offset;
+    const uint y2_idx = y1_idx + 128;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        vec2 d = vec2(data_a[ib0 + i].d);
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+        const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+
+        const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
+        const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
+        const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
+        const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
+
+        const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
+        const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
+        const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
+        const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
+        const FLOAT_TYPE sc4 = scale8_f.x;
+        const FLOAT_TYPE sc5 = scale8_f.y;
+        const FLOAT_TYPE sc6 = scale8_f.z;
+        const FLOAT_TYPE sc7 = scale8_f.w;
+
+        const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+        const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
+
+        const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
+        const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
+        const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
+        const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
+
+        const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
+        const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
+        const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
+        const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
+
+        const FLOAT_TYPE q4_0  = qs0_lo4.x;
+        const FLOAT_TYPE q4_1  = qs0_lo4.y;
+        const FLOAT_TYPE q4_2  = qs0_lo4.z;
+        const FLOAT_TYPE q4_3  = qs0_lo4.w;
+        const FLOAT_TYPE q4_4  = qs0_hi4.x;
+        const FLOAT_TYPE q4_5  = qs0_hi4.y;
+        const FLOAT_TYPE q4_6  = qs0_hi4.z;
+        const FLOAT_TYPE q4_7  = qs0_hi4.w;
+        const FLOAT_TYPE q4_8  = qs64_lo4.x;
+        const FLOAT_TYPE q4_9  = qs64_lo4.y;
+        const FLOAT_TYPE q4_10 = qs64_lo4.z;
+        const FLOAT_TYPE q4_11 = qs64_lo4.w;
+        const FLOAT_TYPE q4_12 = qs64_hi4.x;
+        const FLOAT_TYPE q4_13 = qs64_hi4.y;
+        const FLOAT_TYPE q4_14 = qs64_hi4.z;
+        const FLOAT_TYPE q4_15 = qs64_hi4.w;
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]);
+            vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
+            vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]);
+            vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
+
+            const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
+            const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
+            const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
+            const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
+            const FLOAT_TYPE smin =
+                fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
+                fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
+                fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
+                fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
+            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
     const uint tid = gl_LocalInvocationID.x;
-    const uint itid = tid%16;  // 0...16
-    const uint ix  = tid/16;
+    const uint itid = tid%16;  // 0...15
+    const uint ix = tid/16;
 
-    const uint step = 4;
-
-    const uint il = itid/step;                      // 0...3
-    const uint ir = itid - step*il;                 // 0...7 or 0...3
+    const uint il = itid/4;                         // 0...3
+    const uint ir = itid - 4*il;                    // 0...3
     const uint n =  4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
@@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            vec2 d = vec2(data_a[ib0 + i].d);
-            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
-
-            uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
-            uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
-            uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
-            uvec4 scale0 = uvec4(unpack8(scale0_u32));
-            uvec4 scale4 = uvec4(unpack8(scale4_u32));
-            uvec4 scale8 = uvec4(unpack8(scale8_u32));
-
-            const uint32_t sc0 = (  scale0.x       & 0x3f);
-            const uint32_t sc1 = (  scale0.y       & 0x3f);
-            const uint32_t sc2 = (  scale4.x       & 0x3f);
-            const uint32_t sc3 = (  scale4.y       & 0x3f);
-            const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
-            const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
-            const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
-            const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
-
-            uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
-            uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
-
-            uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
-            uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
-            uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
-            uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
-
-            uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
-            uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
-            uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
-            uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
-
-            const uint32_t q4_0  = qs0_lo4.x;
-            const uint32_t q4_1  = qs0_lo4.y;
-            const uint32_t q4_2  = qs0_lo4.z;
-            const uint32_t q4_3  = qs0_lo4.w;
-            const uint32_t q4_4  = qs0_hi4.x;
-            const uint32_t q4_5  = qs0_hi4.y;
-            const uint32_t q4_6  = qs0_hi4.z;
-            const uint32_t q4_7  = qs0_hi4.w;
-            const uint32_t q4_8  = qs64_lo4.x;
-            const uint32_t q4_9  = qs64_lo4.y;
-            const uint32_t q4_10 = qs64_lo4.z;
-            const uint32_t q4_11 = qs64_lo4.w;
-            const uint32_t q4_12 = qs64_hi4.x;
-            const uint32_t q4_13 = qs64_hi4.y;
-            const uint32_t q4_14 = qs64_hi4.z;
-            const uint32_t q4_15 = qs64_hi4.w;
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]);
-                vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
-                vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]);
-                vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
-
-                const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
-                const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
-                const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
-                const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
-                const FLOAT_TYPE smin =
-                    fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
-                    fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
-                    fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
-                    fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-                temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
-            }
-        }
-    }
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
+        calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
index 79d7db0e3e64b33eb3fd8f474715464b0dc28d09..6c84ef3cde3ff5e8bb6f9d57bb9ae8f7d3655a96 100644 (file)
@@ -6,6 +6,118 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
+    const uint y1_idx = i * QUANT_K + y_offset;
+    const uint y2_idx = y1_idx + 128;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        vec2 d = vec2(data_a[ib0 + i].d);
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+        const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+
+        const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
+        const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
+        const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
+        const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
+
+        const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
+        const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
+        const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
+        const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
+        const FLOAT_TYPE sc4 = scale8_f.x;
+        const FLOAT_TYPE sc5 = scale8_f.y;
+        const FLOAT_TYPE sc6 = scale8_f.z;
+        const FLOAT_TYPE sc7 = scale8_f.w;
+
+        const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
+
+        uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
+        uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
+        uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
+        uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
+
+        const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
+
+        const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
+        const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
+        const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
+        const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
+
+        qs0_16_u32_lo4 += qs0_16_lo4_offset16;
+        qs0_16_u32_hi4 += qs0_16_hi4_offset16;
+        qs64_80_u32_lo4 += qs64_80_lo4_offset16;
+        qs64_80_u32_hi4 += qs64_80_hi4_offset16;
+
+        const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
+        const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
+        const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
+        const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
+
+        const FLOAT_TYPE q4_0  = qs0_16_lo4.x;
+        const FLOAT_TYPE q4_1  = qs0_16_lo4.y;
+        const FLOAT_TYPE q4_2  = qs0_16_lo4.z;
+        const FLOAT_TYPE q4_3  = qs0_16_lo4.w;
+        const FLOAT_TYPE q4_4  = qs0_16_hi4.x;
+        const FLOAT_TYPE q4_5  = qs0_16_hi4.y;
+        const FLOAT_TYPE q4_6  = qs0_16_hi4.z;
+        const FLOAT_TYPE q4_7  = qs0_16_hi4.w;
+        const FLOAT_TYPE q4_8  = qs64_80_lo4.x;
+        const FLOAT_TYPE q4_9  = qs64_80_lo4.y;
+        const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
+        const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
+        const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
+        const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
+        const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
+        const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]);
+            vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]);
+            vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
+            vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
+            vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]);
+            vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]);
+            vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
+            vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
+
+            const FLOAT_TYPE sx =
+              fma(FLOAT_TYPE(by10.x), q4_0,
+              fma(FLOAT_TYPE(by10.y), q4_1,
+              fma(FLOAT_TYPE(by116.x), q4_2,
+                 FLOAT_TYPE(by116.y) * q4_3)));
+            const FLOAT_TYPE sy =
+              fma(FLOAT_TYPE(by132.x), q4_4,
+              fma(FLOAT_TYPE(by132.y), q4_5,
+              fma(FLOAT_TYPE(by148.x), q4_6,
+                 FLOAT_TYPE(by148.y) * q4_7)));
+            const FLOAT_TYPE sz =
+              fma(FLOAT_TYPE(by20.x), q4_8,
+              fma(FLOAT_TYPE(by20.y), q4_9,
+              fma(FLOAT_TYPE(by216.x), q4_10,
+                 FLOAT_TYPE(by216.y) * q4_11)));
+            const FLOAT_TYPE sw =
+              fma(FLOAT_TYPE(by232.x), q4_12,
+              fma(FLOAT_TYPE(by232.y), q4_13,
+              fma(FLOAT_TYPE(by248.x), q4_14,
+                 FLOAT_TYPE(by248.y) * q4_15)));
+            const FLOAT_TYPE smin =
+              fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+              fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+              fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+                  (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
     const uint tid = gl_LocalInvocationID.x;
-    const uint itid = tid%16;  // 0...16
-    const uint ix  = tid/16;
+    const uint itid = tid%16;  // 0...15
+    const uint ix = tid/16;
 
     const uint il = itid/4;                          // 0...3
-    const uint ir = itid - 4*il;                     // 0...7 or 0...3
+    const uint ir = itid - 4*il;                     // 0...3
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
@@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            vec2 d = vec2(data_a[ib0 + i].d);
-            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
-
-            uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
-            uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
-            uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
-            uvec4 scale0 = uvec4(unpack8(scale0_u32));
-            uvec4 scale4 = uvec4(unpack8(scale4_u32));
-            uvec4 scale8 = uvec4(unpack8(scale8_u32));
-
-            const uint32_t sc0 = (  scale0.x       & 0x3f);
-            const uint32_t sc1 = (  scale0.y       & 0x3f);
-            const uint32_t sc2 = (  scale4.x       & 0x3f);
-            const uint32_t sc3 = (  scale4.y       & 0x3f);
-            const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
-            const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
-            const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
-            const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
-
-            uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
-            uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
-
-            uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
-            uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
-            uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
-            uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
-
-            uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
-
-            uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
-            uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
-            uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
-            uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
-
-            qs0_16_u32_lo4 += qs0_16_lo4_offset16;
-            qs0_16_u32_hi4 += qs0_16_hi4_offset16;
-            qs64_80_u32_lo4 += qs64_80_lo4_offset16;
-            qs64_80_u32_hi4 += qs64_80_hi4_offset16;
-
-            uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
-            uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
-            uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
-            uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
-
-            const uint32_t q4_0  = qs0_16_lo4.x;
-            const uint32_t q4_1  = qs0_16_lo4.y;
-            const uint32_t q4_2  = qs0_16_lo4.z;
-            const uint32_t q4_3  = qs0_16_lo4.w;
-            const uint32_t q4_4  = qs0_16_hi4.x;
-            const uint32_t q4_5  = qs0_16_hi4.y;
-            const uint32_t q4_6  = qs0_16_hi4.z;
-            const uint32_t q4_7  = qs0_16_hi4.w;
-            const uint32_t q4_8  = qs64_80_lo4.x;
-            const uint32_t q4_9  = qs64_80_lo4.y;
-            const uint32_t q4_10 = qs64_80_lo4.z;
-            const uint32_t q4_11 = qs64_80_lo4.w;
-            const uint32_t q4_12 = qs64_80_hi4.x;
-            const uint32_t q4_13 = qs64_80_hi4.y;
-            const uint32_t q4_14 = qs64_80_hi4.z;
-            const uint32_t q4_15 = qs64_80_hi4.w;
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]);
-                vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]);
-                vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
-                vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
-                vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]);
-                vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]);
-                vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
-                vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
-
-                const FLOAT_TYPE sx =
-                  fma(FLOAT_TYPE(by10.x), q4_0,
-                  fma(FLOAT_TYPE(by10.y), q4_1,
-                  fma(FLOAT_TYPE(by116.x), q4_2,
-                     FLOAT_TYPE(by116.y) * q4_3)));
-                const FLOAT_TYPE sy =
-                  fma(FLOAT_TYPE(by132.x), q4_4,
-                  fma(FLOAT_TYPE(by132.y), q4_5,
-                  fma(FLOAT_TYPE(by148.x), q4_6,
-                     FLOAT_TYPE(by148.y) * q4_7)));
-                const FLOAT_TYPE sz =
-                  fma(FLOAT_TYPE(by20.x), q4_8,
-                  fma(FLOAT_TYPE(by20.y), q4_9,
-                  fma(FLOAT_TYPE(by216.x), q4_10,
-                     FLOAT_TYPE(by216.y) * q4_11)));
-                const FLOAT_TYPE sw =
-                  fma(FLOAT_TYPE(by232.x), q4_12,
-                  fma(FLOAT_TYPE(by232.y), q4_13,
-                  fma(FLOAT_TYPE(by248.x), q4_14,
-                     FLOAT_TYPE(by248.y) * q4_15)));
-                const FLOAT_TYPE smin =
-                  fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
-                  fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
-                  fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
-                      (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-                temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
-            }
-        }
-    }
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
+        calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
index 041fd27c12b54aa987e6c0a1f6f6b76e2f6fb474..f05f96b5efb9d3ef6b2256514f0099be45bbfefc 100644 (file)
@@ -6,7 +6,77 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+    const uint y_idx = i * QUANT_K + y_offset;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+
+        if (!all_threads) { // when we don't have enough blocks to use all threads
+            barrier();
+            if (i < num_blocks_per_row)
+                sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+            barrier();
+
+            if (i >= num_blocks_per_row)
+                continue;
+        }
+
+        const uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
+        const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
+
+        const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
+        const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
+        const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
+        const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
+
+        const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
+        const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
+        const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
+        const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
+        const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
+
+        const uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32;
+        const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
+        const uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32;
+        const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
+
+        const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
+        const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
+        const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
+        const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
+
+        if (all_threads) {
+            barrier();
+            sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+            barrier();
+        }
+
+        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]);
+            vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]);
+            vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
+            vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
+
+            FLOAT_TYPE sum[4] = {0, 0, 0, 0};
+            [[unroll]] for (uint l = 0; l < 4; ++l) {
+                sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
+                sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
+                sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
+                sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
+            }
+            temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
+        }
+    }
+}
+
+void compute_outputs(const uint first_row, const uint num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
@@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
     const uint tid = gl_LocalInvocationID.x;
-    const uint itid = tid%16;  // 0...16
-    const uint ix  = tid/16;
+    const uint itid = tid%16;  // 0...15
+    const uint ix = tid/16;
 
-    const uint step = 8;
-
-    const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - step*v_im;                     // 0...15 or 0...7
+    const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128...
+    const uint v_in = itid - 8*v_im;                        // 0...7
 
     const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28
     const uint is = v_in / 4;
@@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint s_offset  =  8*v_im + is;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-            FLOAT_TYPE scales[4];
-            scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
-            scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
-            scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
-            scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
-
-            uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
-            uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
-
-            uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
-            uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
-            uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
-            uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
-
-            uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
-            uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
-            uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
-            uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
-            uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
-
-            uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32;
-            uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
-            uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32;
-            uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
-
-            uvec4 q0 = uvec4(unpack8(q0_u32));
-            uvec4 q1 = uvec4(unpack8(q1_u32));
-            uvec4 q2 = uvec4(unpack8(q2_u32));
-            uvec4 q3 = uvec4(unpack8(q3_u32));
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]);
-                vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]);
-                vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
-                vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
-
-                FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-                [[unroll]] for (int l = 0; l < 4; ++l) {
-                    sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
-                          fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
-                          fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
-                          fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
-                }
-                temp[j][n] += sum * d;
-            }
-        }
-    }
+    const uint nbr_par_th = num_blocks_per_row%it_size;
+    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+    uint i0 = 0;
+    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+        calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+    calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }