]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: optimize and reenable split_k (llama/10637)
authorJeff Bolz <redacted>
Tue, 3 Dec 2024 19:29:54 +0000 (13:29 -0600)
committerGeorgi Gerganov <redacted>
Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
Use vector loads when possible in mul_mat_split_k_reduce. Use split_k
when there aren't enough workgroups to fill the shaders.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp

index df6a659f44ff5b17bdf188fce1758073effbbf43..17e1be105f8b11a011d9f2510d92891bc4846e0e 100644 (file)
@@ -165,6 +165,7 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     uint32_t subgroup_size;
+    uint32_t shader_core_count;
     bool uma;
 
     size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -1610,11 +1611,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
         bool maintenance4_support = false;
+        bool sm_builtins = false;
 
         // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
+            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+                sm_builtins = true;
             }
         }
 
@@ -1622,11 +1626,21 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceMaintenance3Properties props3;
         vk::PhysicalDeviceMaintenance4Properties props4;
         vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
+
+        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
+
         if (maintenance4_support) {
-            subgroup_props.pNext = &props4;
+            last_struct->pNext = (VkBaseOutStructure *)&props4;
+            last_struct = (VkBaseOutStructure *)&props4;
+        }
+        if (sm_builtins) {
+            last_struct->pNext = (VkBaseOutStructure *)&sm_props;
+            last_struct = (VkBaseOutStructure *)&sm_props;
         }
+
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
 
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->vendor_id = device->properties.vendorID;
         device->subgroup_size = subgroup_props.subgroupSize;
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+        if (sm_builtins) {
+            device->shader_core_count = sm_props.shaderSMCount;
+        } else {
+            device->shader_core_count = 0;
+        }
 
         bool fp16_storage = false;
         bool fp16_compute = false;
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     dst->device->device.resetFences({ dst->device->fence });
 }
 
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
-    // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-    //     return 4;
-    // }
 
-    return 1;
+    uint32_t split_k = 1;
+    if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+        // If k is 'large' and the SMs will fill less than halfway, use split_k.
+        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
+        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
+        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
+            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            // Clamp to 2 or 4
+            split_k = std::min(split_k, 4u);
+            if (split_k == 3) {
+                split_k = 2;
+            }
+        }
+    }
 
-    GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
+    return split_k;
 }
 
 static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
 
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
+        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
                 (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
                 (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
index 825b91031f569f362d7158bbf028aae74dfa4997..4c64fd47af718d5e43224648e2f5a7b527739723 100644 (file)
@@ -5,7 +5,9 @@
 layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
 
 layout (push_constant) uniform parameter {
     uint ne;
@@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
 } p;
 
 void main() {
-    const uint idx = gl_GlobalInvocationID.x;
+    // Each invocation handles four consecutive components
+    const uint idx = gl_GlobalInvocationID.x * 4;
 
     if (idx >= p.ne) {
         return;
     }
 
-    float result = 0.0f;
+    // Check if all four components are in bounds and aligned,
+    // then use vector loads
+    if (idx + 3 < p.ne && (p.ne % 4) == 0) {
+        vec4 result = vec4(0.0f);
 
-    [[unroll]] for (uint i = 0; i < p.k_num; i++) {
-        result += data_a[i * p.ne + idx];
-    }
+        [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+            result += data_a4[(i * p.ne + idx) / 4];
+        }
+
+        data_d4[idx / 4] = result;
+    } else {
+        [[unroll]] for (uint j = 0; j < 4; ++j) {
+            if (idx + j < p.ne) {
+                float result = 0.0f;
 
-    data_d[idx] = result;
+                [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+                    result += data_a[i * p.ne + idx + j];
+                }
+
+                data_d[idx + j] = result;
+            }
+        }
+    }
 }