]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: optimize mul_mat for small values of N (#10991)
authorJeff Bolz <redacted>
Mon, 30 Dec 2024 17:27:11 +0000 (11:27 -0600)
committerGitHub <redacted>
Mon, 30 Dec 2024 17:27:11 +0000 (18:27 +0100)
Make the mul_mat_vec shaders support N>1 (as a spec constant, NUM_COLS) where
the batch_strides are overloaded to hold the row strides. Put the loads from the
B matrix in the innermost loop because it should cache better.

Share some code for reducing the result values to memory in mul_mat_vec_base.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
tests/test-backend-ops.cpp

index 8e47e79ae328451b89f1d090768937f4a883f6f9..020e612801f0d818a6467861bc68d6ae2d107d80 100644 (file)
@@ -145,6 +145,8 @@ class vk_perf_logger;
 #endif
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
+static constexpr uint32_t mul_mat_vec_max_cols = 8;
+
 struct vk_device_struct {
     std::mutex mutex;
 
@@ -202,8 +204,8 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
@@ -1866,33 +1868,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
     } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
         rm_stdq = 2;
 
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32",  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32",  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
-
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32",  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
+    for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1),  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2892,9 +2896,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 }
 
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
     VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
     GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
+    GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
 
     switch (a_type) {
         case GGML_TYPE_F32:
@@ -2915,7 +2920,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
             return nullptr;
     }
 
-    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
+    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
 }
 
 static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
@@ -3925,8 +3930,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t ne12 = src1->ne[2];
     const uint64_t ne13 = src1->ne[3];
 
-    GGML_ASSERT(ne11 == 1);
-
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
     const uint64_t ne22 = dst->ne[2];
@@ -3935,6 +3938,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t r2 = ne12 / ne02;
     const uint64_t r3 = ne13 / ne03;
 
+    // batch_n indicates that we need to compute a few vector results, and this assumes
+    // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
+    GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
+    bool batch_n = ne11 > 1;
+
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
     ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
@@ -3985,7 +3993,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
+    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
@@ -4057,8 +4065,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
 
-    uint32_t stride_batch_x = ne00*ne01;
-    uint32_t stride_batch_y = ne10*ne11;
+    // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
+    uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
+    uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
+    uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
 
     if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
         stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@@ -4081,7 +4091,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+        stride_batch_x, stride_batch_y, stride_batch_d,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_sync_buffers(subctx);
@@ -4261,7 +4271,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
                !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
         ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
-    } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+    // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
+    // when ne12 and ne13 are one.
+    } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
+               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
index 187c31916d16f3b04825bdcccc1be2e691c5ad5b..24875cdcf4c988424d85ac6ba136a78206980f1a 100644 (file)
@@ -9,9 +9,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
 #if !defined(DATA_A_F32) && !defined(DATA_A_F16)
 #define K_PER_ITER 8
 #else
@@ -21,70 +18,70 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
 
 uint a_offset, b_offset, d_offset, y_offset;
 
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
-void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 {
-    const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
-    const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
-    const uint iybs = col - col%QUANT_K; // y block start index
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
+        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+        const uint iybs = col - col%QUANT_K; // y block start index
 
 #if K_PER_ITER == 8
 #if QUANT_R == 2
-    const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
-    const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
-    const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
-    const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
+        const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
+        const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
+        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
+        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
 #else
-    const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
-    const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
+        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
 #endif
 #else
-    // Check if the second of the pair of elements is OOB, and don't fetch B or
-    // accumulate it. We still fetch a pair of elements for A, which is fine for
-    // quantized formats since they'll be within the same block. We should
-    // probably skip fetching the second element for F16/F32, but as of now we
-    // still do.
-    const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
-
-    FLOAT_TYPE b0 = 0, b1 = 0;
-    b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
-    if (!OOB) {
-        b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
-    }
+        // Check if the second of the pair of elements is OOB, and don't fetch B or
+        // accumulate it. We still fetch a pair of elements for A, which is fine for
+        // quantized formats since they'll be within the same block. We should
+        // probably skip fetching the second element for F16/F32, but as of now we
+        // still do.
+        const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
+
+        FLOAT_TYPE b0 = 0, b1 = 0;
+        b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
+        if (!OOB) {
+            b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
+        }
 #endif
-    uint ibi = first_row*p.ncols;
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib = (ibi + col)/QUANT_K; // block index
-        ibi += p.ncols;
+        uint ibi = first_row*p.ncols;
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib = (ibi + col)/QUANT_K; // block index
+            ibi += p.ncols;
 
 #if K_PER_ITER == 8
-        vec4 v = dequantize4(ib, iqs, a_offset);
-        vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+            vec4 v = dequantize4(ib, iqs, a_offset);
+            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
 
-        const vec2 dm = get_dm(ib, a_offset);
-        if (dm.y != 0) { // quant has min component
-            v = v * dm.x + dm.y;
-            v2 = v2 * dm.x + dm.y;
-        }
+            const vec2 dm = get_dm(ib, a_offset);
+            if (dm.y != 0) { // quant has min component
+                v = v * dm.x + dm.y;
+                v2 = v2 * dm.x + dm.y;
+            }
 
-        // matrix multiplication
-        FLOAT_TYPE rowtmp = dot(bv0, v);
-        rowtmp += dot(bv1, v2);
+            // matrix multiplication
+            FLOAT_TYPE rowtmp = dot(bv0, v);
+            rowtmp += dot(bv1, v2);
 
-        if (dm.y == 0)
-            rowtmp *= dm.x;
+            if (dm.y == 0)
+                rowtmp *= dm.x;
 
-        temp[n] += rowtmp;
+            temp[j][n] += rowtmp;
 #else
-        const vec2 v = dequantize(ib, iqs, a_offset);
+            const vec2 v = dequantize(ib, iqs, a_offset);
 
-        // matrix multiplication
-        temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
-        if (!OOB) {
-            temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
-        }
+            // matrix multiplication
+            temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
+            if (!OOB) {
+                temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
+            }
 #endif
+        }
     }
 }
 
@@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
@@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         i++;
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index 3894fca8260a592a45df5000190cd96f952b5a7b..903753c7e2ec5fc3fed4ab5fc8f9809a9a73678b 100644 (file)
@@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
             batch_idx * p.batch_stride_d;
 #endif
 }
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
+layout (constant_id = 2) const uint NUM_COLS = 1;
+
+shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
+
+void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+    // sum up partial sums and write back result
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            tmpsh[j][n][tid] = temp[j][n];
+        }
+    }
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                    tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
+                }
+            }
+        }
+        barrier();
+    }
+    if (tid == 0) {
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
+            }
+        }
+    }
+}
index 138ad018411bc0fad8a32344fd114bf591168d15..93421344624164fc6ccb5272992efcf5748a4b3c 100644 (file)
@@ -5,11 +5,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
-        B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
-        B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
-        B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
-        B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
-        B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
-        B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
-        B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             f16vec2 d = data_a[ib0 + i].d;
@@ -74,48 +62,42 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             uvec2 qs0 =  uvec2(unpack8(qs0_u16));
             uvec2 qs16 = uvec2(unpack8(qs16_u16));
 
-            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-            [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
-                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
-                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
-                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-                sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
-                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
-                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
-                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
-                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
-                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
-                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
+                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+                FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+                FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+                [[unroll]] for (int l = 0; l < 2; ++l) {
+                    sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
+                           fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                           fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
+                           fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                           fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
+                           fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                           fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
+                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                    sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
+                           fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
+                           fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
+                           fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
+                           fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
+                           fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
+                           fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
+                           fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+                }
+                temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
             }
-            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index 82ec42d257d0c80cca0e40266ac46e491a71653e..86b0159d97a898bf6007136bea0bf4523c2e3e89 100644 (file)
@@ -5,11 +5,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -33,10 +28,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     const uint s_shift = 4 * v_im;
@@ -44,15 +41,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
-        B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
-        B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
-        B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
-        B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
-        B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
-        B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
-        B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -70,39 +58,34 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             u8vec2 s8 = unpack8(s8_16);
             u8vec2 s10 = unpack8(s10_16);
 
-            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-            [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+
+                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
+                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+                FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+                [[unroll]] for (int l = 0; l < 2; ++l) {
+                    sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                          fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+                }
+                temp[j][n] = fma(d, sum, temp[j][n]);
             }
-            temp[n] = fma(d, sum, temp[n]);
         }
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index 677c207a842fc10b85bc4b57f0b1d858a68a42c5..cd1dd8e89c21e9b34fa56ea96619c820eaa7e0e2 100644 (file)
@@ -6,11 +6,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -36,21 +31,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
-        B_TYPE_VEC4 by10 =  data_b_v4[(b_offset + y1_idx) / 4];
-        B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
-        B_TYPE_VEC4 by20 =  data_b_v4[(b_offset + y2_idx) / 4];
-        B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             f16vec2 d = data_a[ib0 + i].d;
@@ -103,37 +95,27 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             const uint32_t q4_14 = qs64_hi4.z;
             const uint32_t q4_15 = qs64_hi4.w;
 
-            const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
-            const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
-            const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
-            const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
-            const FLOAT_TYPE smin =
-                fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
-                fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
-                fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
-                fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-            temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
-        }
-    }
-
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                B_TYPE_VEC4 by10 =  data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
+                B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
+                B_TYPE_VEC4 by20 =  data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
+                B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
+
+                const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
+                const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
+                const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11)));
+                const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
+                const FLOAT_TYPE smin =
+                    fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
+                    fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
+                    fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
+                    fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
+                temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
             }
         }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
     }
+
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index ed3c25d891c79e39b5bf643c3358d316368fab56..0a68891c35a5f9fdac10b4f0a05a5d4bfc27c1cd 100644 (file)
@@ -6,11 +6,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -33,25 +28,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 64*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
-        B_TYPE_VEC2 by10 =  data_b_v2[(b_offset + y1_idx) / 2];
-        B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
-        B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
-        B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
-        B_TYPE_VEC2 by20 =  data_b_v2[(b_offset + y2_idx) / 2];
-        B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
-        B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
-        B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             f16vec2 d = data_a[ib0 + i].d;
@@ -116,53 +104,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             const uint32_t q4_14 = qs64_80_hi4.z;
             const uint32_t q4_15 = qs64_80_hi4.w;
 
-            const FLOAT_TYPE sx =
-              fma(FLOAT_TYPE(by10.x), q4_0,
-              fma(FLOAT_TYPE(by10.y), q4_1,
-              fma(FLOAT_TYPE(by116.x), q4_2,
-                 FLOAT_TYPE(by116.y) * q4_3)));
-            const FLOAT_TYPE sy =
-              fma(FLOAT_TYPE(by132.x), q4_4,
-              fma(FLOAT_TYPE(by132.y), q4_5,
-              fma(FLOAT_TYPE(by148.x), q4_6,
-                 FLOAT_TYPE(by148.y) * q4_7)));
-            const FLOAT_TYPE sz =
-              fma(FLOAT_TYPE(by20.x), q4_8,
-              fma(FLOAT_TYPE(by20.y), q4_9,
-              fma(FLOAT_TYPE(by216.x), q4_10,
-                 FLOAT_TYPE(by216.y) * q4_11)));
-            const FLOAT_TYPE sw =
-              fma(FLOAT_TYPE(by232.x), q4_12,
-              fma(FLOAT_TYPE(by232.y), q4_13,
-              fma(FLOAT_TYPE(by248.x), q4_14,
-                 FLOAT_TYPE(by248.y) * q4_15)));
-            const FLOAT_TYPE smin =
-              fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
-              fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
-              fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
-                  (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-            temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
-        }
-    }
-
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                B_TYPE_VEC2 by10 =  data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
+                B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
+                B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
+                B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
+                B_TYPE_VEC2 by20 =  data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
+                B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
+                B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
+                B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
+
+                const FLOAT_TYPE sx =
+                  fma(FLOAT_TYPE(by10.x), q4_0,
+                  fma(FLOAT_TYPE(by10.y), q4_1,
+                  fma(FLOAT_TYPE(by116.x), q4_2,
+                     FLOAT_TYPE(by116.y) * q4_3)));
+                const FLOAT_TYPE sy =
+                  fma(FLOAT_TYPE(by132.x), q4_4,
+                  fma(FLOAT_TYPE(by132.y), q4_5,
+                  fma(FLOAT_TYPE(by148.x), q4_6,
+                     FLOAT_TYPE(by148.y) * q4_7)));
+                const FLOAT_TYPE sz =
+                  fma(FLOAT_TYPE(by20.x), q4_8,
+                  fma(FLOAT_TYPE(by20.y), q4_9,
+                  fma(FLOAT_TYPE(by216.x), q4_10,
+                     FLOAT_TYPE(by216.y) * q4_11)));
+                const FLOAT_TYPE sw =
+                  fma(FLOAT_TYPE(by232.x), q4_12,
+                  fma(FLOAT_TYPE(by232.y), q4_13,
+                  fma(FLOAT_TYPE(by248.x), q4_14,
+                     FLOAT_TYPE(by248.y) * q4_15)));
+                const FLOAT_TYPE smin =
+                  fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+                  fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+                  fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+                      (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+                temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
             }
         }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
     }
+
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index fab4ff5ff054e51cfb619f9e9cbe878d1d73406a..70e13a56bd73050f6bbe5864410ddf0ae96a7d40 100644 (file)
@@ -6,11 +6,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -36,20 +31,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint s_offset  =  8*v_im + is;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_ROWS];
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
-        temp[i] = FLOAT_TYPE(0);
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
     }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        B_TYPE_VEC4 by0  = data_b_v4[(b_offset + y_idx) / 4];
-        B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
-        B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
-        B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
-
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
             const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -84,35 +76,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             uvec4 q2 = uvec4(unpack8(q2_u32));
             uvec4 q3 = uvec4(unpack8(q3_u32));
 
-            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-            [[unroll]] for (int l = 0; l < 4; ++l) {
-                sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
-                      fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
-                      fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
-                      fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+                B_TYPE_VEC4 by0  = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
+                B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
+                B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
+                B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
+
+                FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+                [[unroll]] for (int l = 0; l < 4; ++l) {
+                    sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
+                          fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
+                          fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
+                          fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+                }
+                temp[j][n] += sum * d;
             }
-            temp[n] += sum * d;
         }
     }
 
-    // sum up partial sums and write back result
-    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
-    }
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
index c79acffd20eaa49f02d8fe08e32060c4a4d2be54..1e892f66365e0f7318320cac756603ce532233df 100644 (file)
@@ -3937,7 +3937,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
-    for (int bs : {1, 512}) {
+    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
                 test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));