ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
- ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
};
void main() {
- const uint global_thread_id = gl_GlobalInvocationID.x;
- const uint i2 = gl_WorkGroupID.y;
+ const uint i1 = gl_GlobalInvocationID.x;
+ const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
const uint i3 = gl_WorkGroupID.z;
- if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
+ if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
- const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
- const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
- [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
- const uint src0_idx = src0_base + i0;
- const uint src1_idx = src1_base + i0;
- sum += src0[src0_idx] * src1[src1_idx];
+
+ if (nc == 4) {
+ sum = dot(
+ vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
+ vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
+ );
+ } else {
+ [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
+ sum += src0[src0_base + i0] * src1[src1_base + i0];
+ }
}
+ const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
dst[dst_idx] = sum;
}