ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
+ if (device->subgroup_arithmetic && device->subgroup_require_full_support) { // both device flags set: build the d128/d256 pipelines from the ssm_scan_subgroup_f32 SPIR-V
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+ } else { // otherwise: same pipeline slots, names and constants, but built from the plain ssm_scan_f32 SPIR-V
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); // NOTE(review): the trailing "true, true" arguments are passed in this fallback branch as well; confirm they are valid when the device flags above are not set
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+ }
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
#version 450
#extension GL_EXT_control_flow_attributes : require
+#if USE_SUBGROUP_ADD // NOTE(review): only the ssm_scan_subgroup_f32 build passes this define; confirm the plain build tolerates the undefined name in #if
+#extension GL_KHR_shader_subgroup_arithmetic : enable // provides subgroupAdd(), used by the USE_SUBGROUP_ADD path
+#endif
#include "types.glsl"
}
barrier();
- for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
- [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
- const uint k = (tid % (w >> 1)) +
- (D_STATE * (tid / (w >> 1))) +
- j * D_STATE * (D_STATE / (w >> 1));
- if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
- stateC[k] += stateC[k + (w >> 1)];
+ [[unroll]]
+ for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { // w is now the add offset itself: D_STATE/2, halved each step, last step at w == SUBGROUP_SIZE
+ [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { // ceil(w * SPLIT_H / D_STATE) passes per thread
+ const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
+ if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { // both operands must lie inside the SPLIT_H * D_STATE range
+ stateC[k] += stateC[k + w]; // fold the element w slots ahead into slot k
}
}
barrier();
}
- [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
+ [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { // max(1, ...) keeps one pass when the integer quotient is 0; bound is now exclusive
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
+ const uint max_idx = SUBGROUP_SIZE - 1 + // same formula as idx with the tid terms replaced by SUBGROUP_SIZE - 1 and (D_STATE - 1) / SUBGROUP_SIZE
+ D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
+ j * D_STATE * (D_STATE / SUBGROUP_SIZE);
- uint lane = tid % SUBGROUP_SIZE;
-
- [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
- if (idx + offset < SPLIT_H * D_STATE) {
- stateC[idx] += stateC[idx + offset];
+ if (idx < SPLIT_H * D_STATE || // enter when this thread's idx is in range, or when max_idx for this j is in range
+ max_idx < SPLIT_H * D_STATE) {
+ float sc; // holds the value stored to d below; in the #else path it is assigned only where tid % SUBGROUP_SIZE == 0
+#if USE_SUBGROUP_ADD
+ sc = stateC[idx];
+ sc = subgroupAdd(sc); // subgroup-wide sum of the per-invocation stateC[idx] values
+#else
+ [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { // fallback: halving reduction through stateC
+ if (idx + offset < SPLIT_H * D_STATE) {
+ stateC[idx] += stateC[idx + offset];
+ }
+ barrier(); // NOTE(review): this barrier is now inside the idx/max_idx check above (the removed code had it outside any bounds check); confirm that check is uniform across the workgroup
}
- barrier();
- }
+ if (tid % SUBGROUP_SIZE == 0) {
+ sc = stateC[idx]; // after the loop the reduced value sits at this thread's idx
+ }
+#endif
- if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
- const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
- d[y_base_idx + i * stride_y + k] = stateC[idx];
+ if (tid % SUBGROUP_SIZE == 0) { // one store per group of SUBGROUP_SIZE consecutive tids
+ const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
+ d[y_base_idx + i * stride_y + k] = sc;
+ }
}
}
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
- string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
+ string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
+ string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); // second build of the same ssm_scan.comp source, with USE_SUBGROUP_ADD defined to 1
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});