]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: dynamic subgroup size for the remaining k quants (llama/10745)
authorEve <redacted>
Tue, 10 Dec 2024 19:33:23 +0000 (19:33 +0000)
committerGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 10:52:16 +0000 (12:52 +0200)
* q5_k

q4_k

q3_k

q2_k

q6_k multi row example

* revert as multi row isnt faster for k quants

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

index ad5535c11d4a1096f5567cfac957e47bb96e58c9..bb9572d76c7c7ed896a3a7f0ffa8603d9108f6d7 100644 (file)
 
 #define MAX_VK_BUFFERS 256
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 1
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 #define VK_CHECK(err, msg)                                          \
     do {                                                            \
         vk::Result err_ = (err);                                    \
@@ -1792,10 +1786,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
 
@@ -1806,10 +1800,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
 
@@ -1820,10 +1814,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
 
index 2ec1af5c75542f9130c4ecdb8cc1b000289024a4..3894fca8260a592a45df5000190cd96f952b5a7b 100644 (file)
@@ -2,8 +2,6 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
 
-#define K_QUANTS_PER_ITERATION 2
-
 #ifdef MUL_MAT_ID
 #define EXPERT_COUNT 8
 #endif
index fcf02210ea6c7ebe4ca68174e8217dfbc8a6d52d..1a5350d99ea14e9a3a9cefeac01b8ad69ae4ae27 100644 (file)
@@ -3,9 +3,11 @@
 
 #include "mul_mat_vec_base.comp"
 
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE tmp[32];
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,22 +22,25 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    // 16 threads are used to process each block
+    const uint it_size = gl_WorkGroupSize.x/16;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint itid = tid%16;  // 0...16
+    const uint ix  = tid/16;
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
+    const uint step = 8;
 
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
+    const uint v_im = itid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const uint v_in = itid - step*v_im;                      // 0...15 or 0...7
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
+    const uint l0 = 2*v_in;                                  // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
     FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         f16vec2 d = data_a[ib0 + i].d;
@@ -71,7 +76,7 @@ void main() {
 
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        [[unroll]] for (int l = 0; l < 2; ++l) {
             sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
                    fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
                    fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
@@ -96,7 +101,7 @@ void main() {
 
     // sum up partial sums and write back result
     barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }
index 723fadde0d52b6033bd2e88ec578140af9d4ea60..b19c3811136bb3fa41736e1baae76e209d7c8c5e 100644 (file)
@@ -3,9 +3,11 @@
 
 #include "mul_mat_vec_base.comp"
 
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE tmp[32];
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,17 +22,20 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    // 16 threads are used to process each block
+    const uint it_size = gl_WorkGroupSize.x/16;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint itid = tid%16;  // 0...16
+    const uint ix  = tid/16;
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
+    const uint step = 8;
 
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
+    const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128...
+    const uint v_in = itid - step*v_im;                     // 0...15 or 0...7
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
+    const uint l0 = 2*v_in;                                // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
@@ -38,7 +43,7 @@ void main() {
 
     const uint s_shift = 4 * v_im;
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -66,7 +71,7 @@ void main() {
         u8vec2 s10 = unpack8(s10_16);
 
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        [[unroll]] for (int l = 0; l < 2; ++l) {
             sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
                   fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
                   fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
@@ -83,7 +88,7 @@ void main() {
 
     // sum up partial sums and write back result
     barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }
index 5846f2e86f3e33a95b09d424fb995ee6a48f8f8a..b86d28589c64c49a0b7192a561a72f2358ce496a 100644 (file)
@@ -4,11 +4,12 @@
 
 #include "mul_mat_vec_base.comp"
 
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE tmp[32];
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
-// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
@@ -22,14 +23,17 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    // 16 threads are used to process each block
+    const uint it_size = gl_WorkGroupSize.x/16;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint itid = tid%16;  // 0...16
+    const uint ix  = tid/16;
 
-    const uint step = 8/K_QUANTS_PER_ITERATION;             // 8 or 4
+    const uint step = 4;
 
-    const uint il = tid/step;                               // 0...3
-    const uint ir = tid - step*il;                          // 0...7 or 0...3
-    const uint n =  2 * K_QUANTS_PER_ITERATION;             // 2 or 4
+    const uint il = itid/step;                               // 0...3
+    const uint ir = itid - step*il;                          // 0...7 or 0...3
+    const uint n =  4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
@@ -40,7 +44,7 @@ void main() {
 
     FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
@@ -115,7 +119,7 @@ void main() {
 
     // sum up partial sums and write back result
     barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }
index b455cbd31ec913501bb03733955a16f70c75f83a..fd243cf9161fc4288897508b55fa5c001f575b8a 100644 (file)
@@ -4,9 +4,11 @@
 
 #include "mul_mat_vec_base.comp"
 
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE tmp[32];
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -21,11 +23,14 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/2;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%2;  // 0 or 0, 1
+    // 16 threads are used to process each block
+    const uint it_size = gl_WorkGroupSize.x/16;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint itid = tid%16;  // 0...16
+    const uint ix  = tid/16;
 
-    const uint il = tid/4;                           // 0...3
-    const uint ir = tid - 4*il;                      // 0...7 or 0...3
+    const uint il = itid/4;                          // 0...3
+    const uint ir = itid - 4*il;                     // 0...7 or 0...3
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
@@ -36,7 +41,7 @@ void main() {
 
     FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
@@ -143,7 +148,7 @@ void main() {
 
     // sum up partial sums and write back result
     barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }