]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: im2col and matmul optimizations for stable diffusion (#10942)
authorJeff Bolz <redacted>
Sun, 29 Dec 2024 09:16:34 +0000 (03:16 -0600)
committerGitHub <redacted>
Sun, 29 Dec 2024 09:16:34 +0000 (10:16 +0100)
* tests: Add im2col perf tests

* vulkan: optimize im2col, more elements per thread

* vulkan: increase small tile size for NV_coopmat2

* vulkan: change im2col to 512 elements per workgroup

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
tests/test-backend-ops.cpp

index 6dfc60c9b5968b51841bb549294b1ec45193bf95..8e47e79ae328451b89f1d090768937f4a883f6f9 100644 (file)
@@ -1404,10 +1404,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
         l_warptile = { 256, 128, 256, 64 };
         m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128,  32,  16, 64 };
+        s_warptile = { 128,  64,  64, 64 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
-        s_wg_denoms = { 32,  16, 1 };
+        s_wg_denoms = { 64,  64, 1 };
 
         // spec constants and tile sizes for quant matmul (non-Qi_K)
         l_warptile_mmq = { 256, 128, 256, 64 };
@@ -2017,11 +2017,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
index 966fedf8fa7a0f1e6adf8560c0714247c9572db8..122b1e93fb496d8a717c3f03161b4af695e9b8d4 100644 (file)
@@ -2,6 +2,7 @@
 
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_spirv_intrinsics: enable
+#extension GL_EXT_control_flow_attributes : require
 
 #if RTE16
 spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
@@ -23,40 +24,64 @@ layout (push_constant) uniform parameter
 
 #include "types.comp"
 
-#define BLOCK_SIZE 256
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+const uint NUM_ITER = 512 / BLOCK_SIZE;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 void main() {
-    const uint i = gl_GlobalInvocationID.x;
-    if (i >= p.pelements) {
-        return;
-    }
-
-    const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
-    const uint kx = i / ksize;
-    const uint kd = kx * ksize;
-    const uint ky = (i - kd) / p.OW;
-    const uint ix = i % p.OW;
+    const uint gidx = gl_GlobalInvocationID.x;
 
     const uint oh = gl_GlobalInvocationID.y;
     const uint batch = gl_GlobalInvocationID.z / p.IC;
     const uint ic = gl_GlobalInvocationID.z % p.IC;
 
-    const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
-    const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
+    A_TYPE values[NUM_ITER];
+    uint offset_dst[NUM_ITER];
+    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+        values[idx] = A_TYPE(0);
+    }
+
+    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+        const uint i = gidx * NUM_ITER + idx;
+
+        const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+        const uint kx = i / ksize;
+        const uint kd = kx * ksize;
+        const uint ky = (i - kd) / p.OW;
+        const uint ix = i % p.OW;
+
+        const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
+        const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
 
-    const uint offset_dst =
-        ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
-        (ic * (p.KW * p.KH) + ky * p.KW + kx);
+        offset_dst[idx] =
+            ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
+            (ic * (p.KW * p.KH) + ky * p.KW + kx);
 
-    if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
-        data_d[offset_dst] = D_TYPE(0.0f);
-    } else {
-        const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
-        data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
+        if (i >= p.pelements) {
+            continue;
+        }
+
+        if (iih < p.IH && iiw < p.IW) {
+            const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
+            values[idx] = data_a[offset_src + iih * p.IW + iiw];
+        }
     }
+
+    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+        const uint i = gidx * NUM_ITER + idx;
+
+        if (i >= p.pelements) {
+            continue;
+        }
+
+        data_d[offset_dst[idx]] = D_TYPE(values[idx]);
+    }
+
 }
index ccdd3fb57a5041452fa0cc7d275a4e45fe95f5e7..c79acffd20eaa49f02d8fe08e32060c4a4d2be54 100644 (file)
@@ -3945,6 +3945,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (int K : {3, 5}) {
+        for (int IC : {256, 2560}) {
+            for (int IW_IH : {32, 64, 256}) {
+                if (IC == 2560 && IW_IH == 256) {
+                    // too big
+                    continue;
+                }
+                test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));
+            }
+        }
+    }
+
     return test_cases;
 }