]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: multi-row k quants (#10846)
authorEve <redacted>
Thu, 26 Dec 2024 15:54:44 +0000 (10:54 -0500)
committerGitHub <redacted>
Thu, 26 Dec 2024 15:54:44 +0000 (16:54 +0100)
* multi row k quant shaders!

* better row selection

* more row choices

* readjust row selection

* rm_kq=2 by default

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

index 323ce7cf332279722dc689e6c91f212681834b81..c0a43631c87968b33d5900d906c2c94c4430ad10 100644 (file)
@@ -1855,53 +1855,58 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     // mul mat vec
 
-    // AMD GCN and Intel graphics cards perform best when the number of rows per shader is doubled
-    uint32_t rm = 1;
-    if ((device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size == 64 && device->subgroup_max_size == 64) || device->vendor_id == VK_VENDOR_ID_INTEL)
-        rm = 2;
+    // the number of rows computed per shader depends on GPU model and quant
+    uint32_t rm_stdq = 1;
+    uint32_t rm_kq = 2;
+    if (device->vendor_id == VK_VENDOR_ID_AMD) {
+        if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+            rm_stdq = 2;
+            rm_kq = 4;
+        }
+    } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
+        rm_stdq = 2;
 
-    // computing additional rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32",  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32",  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32",  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
 
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
index 1a5350d99ea14e9a3a9cefeac01b8ad69ae4ae27..138ad018411bc0fad8a32344fd114bf591168d15 100644 (file)
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = d.x;
-        const FLOAT_TYPE dmin = d.y;
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@ void main() {
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
 
-        uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
-        uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-        uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-        uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-        uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-        uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-        uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-        uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
-
-        uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
-        uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-        uvec2 qs0 =  uvec2(unpack8(qs0_u16));
-        uvec2 qs16 = uvec2(unpack8(qs16_u16));
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
-                   fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
-                   fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
-                   fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-            sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
-                   fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
-                   fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
-                   fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
-                   fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
-                   fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
-                   fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = d.x;
+            const FLOAT_TYPE dmin = d.y;
+
+            uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
+            uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
+
+            uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
+            uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
+            uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
+
+            uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
+            uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
+            uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
+            uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
+
+            uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
+            uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
+            uvec2 qs0 =  uvec2(unpack8(qs0_u16));
+            uvec2 qs16 = uvec2(unpack8(qs16_u16));
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
+                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
+                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
+                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
+                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
+                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
+                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
+                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
+                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
+                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            }
+            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
-        temp = fma(dall, sum1, fma(-dmin, sum2, temp));
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
index b19c3811136bb3fa41736e1baae76e209d7c8c5e..82ec42d257d0c80cca0e40266ac46e491a71653e 100644 (file)
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -35,19 +29,21 @@ void main() {
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
-    const uint l0 = 2*v_in;                                // 0...15
+    const uint l0 = 2*v_in;                                 // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     const uint s_shift = 4 * v_im;
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -57,44 +53,68 @@ void main() {
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
 
-        uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
-        uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
-        uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
-        uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
-        uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
-        uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
-        u8vec2 s0 = unpack8(s0_16);
-        u8vec2 s2 = unpack8(s2_16);
-        u8vec2 s4 = unpack8(s4_16);
-        u8vec2 s6 = unpack8(s6_16);
-        u8vec2 s8 = unpack8(s8_16);
-        u8vec2 s10 = unpack8(s10_16);
-
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+            uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
+            uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
+            uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
+            uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
+            uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
+            uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
+            u8vec2 s0 = unpack8(s0_16);
+            u8vec2 s2 = unpack8(s2_16);
+            u8vec2 s4 = unpack8(s4_16);
+            u8vec2 s6 = unpack8(s6_16);
+            u8vec2 s8 = unpack8(s8_16);
+            u8vec2 s10 = unpack8(s10_16);
+
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+            }
+            temp[n] = fma(d, sum, temp[n]);
         }
-        temp = fma(d, sum, temp);
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
index b86d28589c64c49a0b7192a561a72f2358ce496a..677c207a842fc10b85bc4b57f0b1d858a68a42c5 100644 (file)
@@ -7,21 +7,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -31,8 +25,8 @@ void main() {
 
     const uint step = 4;
 
-    const uint il = itid/step;                               // 0...3
-    const uint ir = itid - step*il;                          // 0...7 or 0...3
+    const uint il = itid/step;                      // 0...3
+    const uint ir = itid - step*il;                 // 0...7 or 0...3
     const uint n =  4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
@@ -42,90 +36,116 @@ void main() {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
-
-        uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
-        uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
-        uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
-        uvec4 scale0 = uvec4(unpack8(scale0_u32));
-        uvec4 scale4 = uvec4(unpack8(scale4_u32));
-        uvec4 scale8 = uvec4(unpack8(scale8_u32));
-
-        const uint32_t sc0 = (  scale0.x       & 0x3f);
-        const uint32_t sc1 = (  scale0.y       & 0x3f);
-        const uint32_t sc2 = (  scale4.x       & 0x3f);
-        const uint32_t sc3 = (  scale4.y       & 0x3f);
-        const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
-        const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
-        const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
-        const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
-
-        uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
-        uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
-
-        uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
-        uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
-        uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
-        uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
-        uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
-        uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
-
-        const uint32_t q4_0  = qs0_lo4.x;
-        const uint32_t q4_1  = qs0_lo4.y;
-        const uint32_t q4_2  = qs0_lo4.z;
-        const uint32_t q4_3  = qs0_lo4.w;
-        const uint32_t q4_4  = qs0_hi4.x;
-        const uint32_t q4_5  = qs0_hi4.y;
-        const uint32_t q4_6  = qs0_hi4.z;
-        const uint32_t q4_7  = qs0_hi4.w;
-        const uint32_t q4_8  = qs64_lo4.x;
-        const uint32_t q4_9  = qs64_lo4.y;
-        const uint32_t q4_10 = qs64_lo4.z;
-        const uint32_t q4_11 = qs64_lo4.w;
-        const uint32_t q4_12 = qs64_hi4.x;
-        const uint32_t q4_13 = qs64_hi4.y;
-        const uint32_t q4_14 = qs64_hi4.z;
-        const uint32_t q4_15 = qs64_hi4.w;
-
         B_TYPE_VEC4 by10 =  data_b_v4[(b_offset + y1_idx) / 4];
         B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
         B_TYPE_VEC4 by20 =  data_b_v4[(b_offset + y2_idx) / 4];
         B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
 
-        const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
-        const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
-        const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
-        const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
-        const FLOAT_TYPE smin =
-            fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
-            fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
-            fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
-            fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-        temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+            uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+            uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+            uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+            uvec4 scale0 = uvec4(unpack8(scale0_u32));
+            uvec4 scale4 = uvec4(unpack8(scale4_u32));
+            uvec4 scale8 = uvec4(unpack8(scale8_u32));
+
+            const uint32_t sc0 = (  scale0.x       & 0x3f);
+            const uint32_t sc1 = (  scale0.y       & 0x3f);
+            const uint32_t sc2 = (  scale4.x       & 0x3f);
+            const uint32_t sc3 = (  scale4.y       & 0x3f);
+            const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
+            const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
+            const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
+            const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
+
+            uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+            uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
+
+            uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
+            uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
+            uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
+
+            uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
+            uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
+            uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
+            uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
+
+            const uint32_t q4_0  = qs0_lo4.x;
+            const uint32_t q4_1  = qs0_lo4.y;
+            const uint32_t q4_2  = qs0_lo4.z;
+            const uint32_t q4_3  = qs0_lo4.w;
+            const uint32_t q4_4  = qs0_hi4.x;
+            const uint32_t q4_5  = qs0_hi4.y;
+            const uint32_t q4_6  = qs0_hi4.z;
+            const uint32_t q4_7  = qs0_hi4.w;
+            const uint32_t q4_8  = qs64_lo4.x;
+            const uint32_t q4_9  = qs64_lo4.y;
+            const uint32_t q4_10 = qs64_lo4.z;
+            const uint32_t q4_11 = qs64_lo4.w;
+            const uint32_t q4_12 = qs64_hi4.x;
+            const uint32_t q4_13 = qs64_hi4.y;
+            const uint32_t q4_14 = qs64_hi4.z;
+            const uint32_t q4_15 = qs64_hi4.w;
+
+            const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
+            const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
+            const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
+            const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
+            const FLOAT_TYPE smin =
+                fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
+                fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
+                fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
+                fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
+            temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
+        }
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
index fd243cf9161fc4288897508b55fa5c001f575b8a..ed3c25d891c79e39b5bf643c3358d316368fab56 100644 (file)
@@ -7,21 +7,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -39,74 +33,16 @@ void main() {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
-
-        uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
-        uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
-        uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
-        uvec4 scale0 = uvec4(unpack8(scale0_u32));
-        uvec4 scale4 = uvec4(unpack8(scale4_u32));
-        uvec4 scale8 = uvec4(unpack8(scale8_u32));
-
-        const uint32_t sc0 = (  scale0.x       & 0x3f);
-        const uint32_t sc1 = (  scale0.y       & 0x3f);
-        const uint32_t sc2 = (  scale4.x       & 0x3f);
-        const uint32_t sc3 = (  scale4.y       & 0x3f);
-        const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
-        const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
-        const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
-        const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
-
-        uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
-        uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
-
-        uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
-        uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
-        uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
-
-        uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
-
-        uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
-        uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
-        uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
-        uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
-
-        qs0_16_u32_lo4 += qs0_16_lo4_offset16;
-        qs0_16_u32_hi4 += qs0_16_hi4_offset16;
-        qs64_80_u32_lo4 += qs64_80_lo4_offset16;
-        qs64_80_u32_hi4 += qs64_80_hi4_offset16;
-
-        uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
-        uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
-        uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
-        uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
-
-        const uint32_t q4_0  = qs0_16_lo4.x;
-        const uint32_t q4_1  = qs0_16_lo4.y;
-        const uint32_t q4_2  = qs0_16_lo4.z;
-        const uint32_t q4_3  = qs0_16_lo4.w;
-        const uint32_t q4_4  = qs0_16_hi4.x;
-        const uint32_t q4_5  = qs0_16_hi4.y;
-        const uint32_t q4_6  = qs0_16_hi4.z;
-        const uint32_t q4_7  = qs0_16_hi4.w;
-        const uint32_t q4_8  = qs64_80_lo4.x;
-        const uint32_t q4_9  = qs64_80_lo4.y;
-        const uint32_t q4_10 = qs64_80_lo4.z;
-        const uint32_t q4_11 = qs64_80_lo4.w;
-        const uint32_t q4_12 = qs64_80_hi4.x;
-        const uint32_t q4_13 = qs64_80_hi4.y;
-        const uint32_t q4_14 = qs64_80_hi4.z;
-        const uint32_t q4_15 = qs64_80_hi4.w;
-
         B_TYPE_VEC2 by10 =  data_b_v2[(b_offset + y1_idx) / 2];
         B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
         B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
@@ -116,45 +52,129 @@ void main() {
         B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
         B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
 
-        const FLOAT_TYPE sx =
-          fma(FLOAT_TYPE(by10.x), q4_0,
-          fma(FLOAT_TYPE(by10.y), q4_1,
-          fma(FLOAT_TYPE(by116.x), q4_2,
-             FLOAT_TYPE(by116.y) * q4_3)));
-        const FLOAT_TYPE sy =
-          fma(FLOAT_TYPE(by132.x), q4_4,
-          fma(FLOAT_TYPE(by132.y), q4_5,
-          fma(FLOAT_TYPE(by148.x), q4_6,
-             FLOAT_TYPE(by148.y) * q4_7)));
-        const FLOAT_TYPE sz =
-          fma(FLOAT_TYPE(by20.x), q4_8,
-          fma(FLOAT_TYPE(by20.y), q4_9,
-          fma(FLOAT_TYPE(by216.x), q4_10,
-             FLOAT_TYPE(by216.y) * q4_11)));
-        const FLOAT_TYPE sw =
-          fma(FLOAT_TYPE(by232.x), q4_12,
-          fma(FLOAT_TYPE(by232.y), q4_13,
-          fma(FLOAT_TYPE(by248.x), q4_14,
-             FLOAT_TYPE(by248.y) * q4_15)));
-        const FLOAT_TYPE smin =
-          fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
-          fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
-          fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
-              (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-        temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+            uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+            uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+            uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+            uvec4 scale0 = uvec4(unpack8(scale0_u32));
+            uvec4 scale4 = uvec4(unpack8(scale4_u32));
+            uvec4 scale8 = uvec4(unpack8(scale8_u32));
+
+            const uint32_t sc0 = (  scale0.x       & 0x3f);
+            const uint32_t sc1 = (  scale0.y       & 0x3f);
+            const uint32_t sc2 = (  scale4.x       & 0x3f);
+            const uint32_t sc3 = (  scale4.y       & 0x3f);
+            const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
+            const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
+            const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
+            const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
+
+            uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+            uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
+
+            uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
+            uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
+            uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
+
+            uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
+
+            uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
+            uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
+            uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
+            uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
+
+            qs0_16_u32_lo4 += qs0_16_lo4_offset16;
+            qs0_16_u32_hi4 += qs0_16_hi4_offset16;
+            qs64_80_u32_lo4 += qs64_80_lo4_offset16;
+            qs64_80_u32_hi4 += qs64_80_hi4_offset16;
+
+            uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
+            uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
+            uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
+            uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
+
+            const uint32_t q4_0  = qs0_16_lo4.x;
+            const uint32_t q4_1  = qs0_16_lo4.y;
+            const uint32_t q4_2  = qs0_16_lo4.z;
+            const uint32_t q4_3  = qs0_16_lo4.w;
+            const uint32_t q4_4  = qs0_16_hi4.x;
+            const uint32_t q4_5  = qs0_16_hi4.y;
+            const uint32_t q4_6  = qs0_16_hi4.z;
+            const uint32_t q4_7  = qs0_16_hi4.w;
+            const uint32_t q4_8  = qs64_80_lo4.x;
+            const uint32_t q4_9  = qs64_80_lo4.y;
+            const uint32_t q4_10 = qs64_80_lo4.z;
+            const uint32_t q4_11 = qs64_80_lo4.w;
+            const uint32_t q4_12 = qs64_80_hi4.x;
+            const uint32_t q4_13 = qs64_80_hi4.y;
+            const uint32_t q4_14 = qs64_80_hi4.z;
+            const uint32_t q4_15 = qs64_80_hi4.w;
+
+            const FLOAT_TYPE sx =
+              fma(FLOAT_TYPE(by10.x), q4_0,
+              fma(FLOAT_TYPE(by10.y), q4_1,
+              fma(FLOAT_TYPE(by116.x), q4_2,
+                 FLOAT_TYPE(by116.y) * q4_3)));
+            const FLOAT_TYPE sy =
+              fma(FLOAT_TYPE(by132.x), q4_4,
+              fma(FLOAT_TYPE(by132.y), q4_5,
+              fma(FLOAT_TYPE(by148.x), q4_6,
+                 FLOAT_TYPE(by148.y) * q4_7)));
+            const FLOAT_TYPE sz =
+              fma(FLOAT_TYPE(by20.x), q4_8,
+              fma(FLOAT_TYPE(by20.y), q4_9,
+              fma(FLOAT_TYPE(by216.x), q4_10,
+                 FLOAT_TYPE(by216.y) * q4_11)));
+            const FLOAT_TYPE sw =
+              fma(FLOAT_TYPE(by232.x), q4_12,
+              fma(FLOAT_TYPE(by232.y), q4_13,
+              fma(FLOAT_TYPE(by248.x), q4_14,
+                 FLOAT_TYPE(by248.y) * q4_15)));
+            const FLOAT_TYPE smin =
+              fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+              fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+              fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+                  (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+            temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
+        }
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
index 760aff85499f40e720f470b8f73d527bd5dec918..fab4ff5ff054e51cfb619f9e9cbe878d1d73406a 100644 (file)
@@ -7,21 +7,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -42,69 +36,95 @@ void main() {
     const uint s_offset  =  8*v_im + is;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y_idx   = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-        FLOAT_TYPE scales[4];
-        scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
-        scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
-        scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
-        scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
-
-        uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
-        uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
-
-        uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
-        uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
-        uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
-
-        uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
-        uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
-        uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
-        uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
-        uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
+    FLOAT_TYPE temp[NUM_ROWS];
 
-        uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32;
-        uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
-        uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32;
-        uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
-        uvec4 q0 = uvec4(unpack8(q0_u32));
-        uvec4 q1 = uvec4(unpack8(q1_u32));
-        uvec4 q2 = uvec4(unpack8(q2_u32));
-        uvec4 q3 = uvec4(unpack8(q3_u32));
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
+        const uint y_idx = i * QUANT_K + y_offset;
 
         B_TYPE_VEC4 by0  = data_b_v4[(b_offset + y_idx) / 4];
         B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
         B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
         B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
 
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
-                  fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
-                  fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
-                  fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+            FLOAT_TYPE scales[4];
+            scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
+            scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
+            scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
+            scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
+
+            uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
+            uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
+
+            uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
+            uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
+            uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
+
+            uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
+            uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
+            uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
+            uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
+            uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
+
+            uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32;
+            uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
+            uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32;
+            uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
+
+            uvec4 q0 = uvec4(unpack8(q0_u32));
+            uvec4 q1 = uvec4(unpack8(q1_u32));
+            uvec4 q2 = uvec4(unpack8(q2_u32));
+            uvec4 q3 = uvec4(unpack8(q3_u32));
+
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 4; ++l) {
+                sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
+                      fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
+                      fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
+                      fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+            }
+            temp[n] += sum * d;
         }
-        temp += sum * d;
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
     // sum up partial sums and write back result
-
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }