]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Optimize binary ops (#10270)
authorJeff Bolz <redacted>
Thu, 14 Nov 2024 05:22:55 +0000 (23:22 -0600)
committerGitHub <redacted>
Thu, 14 Nov 2024 05:22:55 +0000 (06:22 +0100)
Reuse the index calculations across all of src0/src1/dst. Add a shader
variant for when src0/src1 are the same dimensions and additional modulus
for src1 aren't needed. Div/mod are slow, so add "fast" div/mod that
have a fast path when the calculation isn't needed or can be done more
cheaply.

ggml/src/ggml-vulkan.cpp
ggml/src/vulkan-shaders/acc.comp
ggml/src/vulkan-shaders/add.comp
ggml/src/vulkan-shaders/concat.comp
ggml/src/vulkan-shaders/div.comp
ggml/src/vulkan-shaders/generic_binary_head.comp
ggml/src/vulkan-shaders/get_rows.comp
ggml/src/vulkan-shaders/get_rows_quant.comp
ggml/src/vulkan-shaders/mul.comp

index 361fdcbc159a973993c764cffa6c3c0045e65cfe..c02c356657a589de342aa609041c670bed074ebd 100644 (file)
@@ -192,9 +192,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
-    vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
-    vk_pipeline pipeline_mul_f32;
-    vk_pipeline pipeline_div_f32;
+    vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
+    vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
+    vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
@@ -1456,13 +1457,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3801,20 +3806,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_add_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
         }
         if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_add_f16_f32_f16;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_mul_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
         }
         return nullptr;
     case GGML_OP_DIV:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_div_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
         }
         return nullptr;
     case GGML_OP_CONCAT:
index 4c8739efee22718bb58b2798b1390cdb271fe75f..4f5a04e71ce11a3d84f7feffc49cd5618d139834 100644 (file)
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = gl_GlobalInvocationID.x;
     if (idx >= p.ne) {
@@ -15,10 +17,13 @@ void main() {
     const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
     const uint ox = src1_i % p.nb01;
 
+    uint i00, i01, i02, i03;
+    get_indices(idx, i00, i01, i02, i03);
+
     if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
     } else {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
     }
 }
 
index 3974845d637ab579e4e0a47e061192adf31e5fbe..da61b76dfc8a38a586763cb0d4d8f4ad477ac38f 100644 (file)
@@ -1,14 +1,29 @@
 #version 450
 
+#extension GL_EXT_shader_16bit_storage : require
+
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }
index c23b6eb1b0cd50ba29b03cdd7dadff746c7e4321..683f9ac3c1d633c49a2243224ec27e1b5e225368 100644 (file)
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
     const int dim = p.param3;
index 8cfce58b150165039917cb474fddb5416efcceff..e581905b3f049a38f251f35f53e5777525d1f342 100644 (file)
@@ -3,12 +3,25 @@
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }
index b6beaff1cf65adf0b9ddef08fe3fe075aca88c41..a6555fa27063ece3f9cc7425b77483b71352c873 100644 (file)
@@ -1,4 +1,5 @@
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
 
 layout (push_constant) uniform parameter
 {
@@ -10,43 +11,50 @@ layout (push_constant) uniform parameter
     float param1; float param2; int param3;
 } p;
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
+// true if src0/src1 are the same shape and the indices can be reused without additional modulus
+layout(constant_id = 0) const bool norepeat = false;
+
 uint get_idx() {
     return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 }
 
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
+uint fastmod(uint a, uint b) {
+    if ((b & (b-1)) == 0) {
+        return a & (b-1);
+    }
+    return a % b;
 }
 
-uint src1_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+uint fastdiv(uint a, uint b) {
+    return (a < b) ? 0 : (a / b);
+}
+
+void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
+    i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
     const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+    i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
     const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+    i01 = (idx - i03_offset - i02_offset) / p.ne00;
+    i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+}
+
+uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
+    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
 
-    return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
+    if (norepeat) {
+        return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
+    } else {
+        return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
+    }
 }
 
-uint dst_idx(uint idx) {
-    const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
-    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
-    const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
-    const uint i22_offset = i22*p.ne21*p.ne20;
-    const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
-    const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
-    return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
+    return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
 }
index e9ff22efa9c7ae74baf6b11a1f5c7005097b7162..a7b81e52ce55077010c0c4f5c49b51d881e39a8a 100644 (file)
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint i00 = gl_GlobalInvocationID.x;
     const uint i10 = gl_GlobalInvocationID.y;
index 53a9a96f2360afe5f0cf3cf793cac45682538cfe..8d30b63c165b2effbf85bc6deb3c40957b5a713a 100644 (file)
@@ -4,6 +4,8 @@
 #include "generic_binary_head.comp"
 #include "dequant_funcs.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint i00 = (gl_GlobalInvocationID.x)*2;
     const uint i10 = gl_GlobalInvocationID.y;
index bfb61c92d688e9d7d1e58538832ffa302e17610d..5ce57cbcfc7fbce4246253de9f0cea1109cf32bf 100644 (file)
@@ -3,12 +3,25 @@
 #include "types.comp"
 #include "generic_binary_head.comp"
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }