// NOTE(review): This fragment is a diff-style patch (lines prefixed '+'/'-'
// are additions/removals) that concatenates hunks from a C++ Vulkan backend
// and a GLSL compute shader. The change converts the ssm_scan kernel from a
// SPLIT_H=16 row-tiling scheme to a one-subgroup-per-output-channel scheme.
// Only commentary is added below; every original line (including its diff
// marker) is preserved byte-for-byte.

// --- C++: device capability flags (interior of a device struct; start of
// --- the struct is outside this view) ---
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
// New flag: basic subgroup support (gl_SubgroupID etc.) in compute stages;
// assigned further below from vk11_props.
+ bool subgroup_basic;
bool subgroup_arithmetic;
bool subgroup_shuffle;
bool subgroup_ballot;
// --- C++: pipeline creation (interior of an init function) ---
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
// The subgroup-add shader variant drops its third specialization constant
// (the old SPLIT_H=16; see the shader hunk below which removes
// constant_id=2), so the spec-constant list shrinks from
// {D_STATE, subgroup_size, 16} to {D_STATE, subgroup_size}.
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
} else {
// NOTE(review): the fallback (non-subgroup-add) pipelines still pass the
// third constant {.., 16}, but the shader hunk below removes
// constant_id=2 entirely — if both paths compile the SAME updated shader
// source, the extra constant is unused/dangling here. Presumably the
// else-branch shader (ssm_scan_f32_*) is a separate source that kept
// SPLIT_H; verify, otherwise these two lines need the same trim.
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
// --- C++: capability detection from VkPhysicalDeviceVulkan11Properties ---
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
// subgroup_basic := compute-stage subgroups with the Basic feature bit
// (required for gl_SubgroupID / gl_SubgroupInvocationID used by the shader).
+ device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
// NOTE(review): this #ifdef looks like unrelated context carried along from
// elided surroundings; it is not closed within this hunk — confirm against
// the full file.
#ifdef __APPLE__
// --- C++: dispatch-size computation (interior of an op-dispatch function) ---
std::array<uint32_t, 3> elements;
- const int splitH = 16;
- const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
// New scheme: each subgroup handles one output channel; a workgroup packs
// d_state / subgroup_size subgroups.
+ const uint32_t d_state = src0->ne[0];
+ uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
// NOTE(review): if d_state < subgroup_size, num_subgroups is 0 and this
// CEIL_DIV divides by zero. Presumably a supports-check elsewhere rejects
// that case (d_state is 128/256 per the pipelines above) — TODO confirm.
+ const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
return false;
}
// --- C++: op-support checks (interior of a supports_op-style function) ---
- const uint32_t SPLIT_H = 16;
// Shared-memory need drops from SPLIT_H * d_state floats to d_state floats:
// the rewritten shader only declares shared temp[D_STATE] (and only in the
// non-subgroupAdd fallback path — see the shader below).
+ size_t shmem_size = d_state * sizeof(float);
- size_t stateC_size = SPLIT_H * d_state * sizeof(float);
+ if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
+ return false;
+ }
- if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
// The shader now unconditionally uses gl_SubgroupID/gl_SubgroupInvocationID,
// so basic subgroup support becomes a hard requirement.
+ if (!device->subgroup_basic) {
return false;
}
// =====================================================================
// GLSL compute shader hunk (separate file from the C++ above)
// =====================================================================
#version 450
#extension GL_EXT_control_flow_attributes : require
// Needed for gl_SubgroupID / gl_SubgroupInvocationID / gl_WorkGroupID-based
// subgroup addressing used below.
+#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
// constant_id=2 (SPLIT_H) removed — matches the trimmed spec-constant list
// in the pipeline-creation hunk above.
-layout(constant_id = 2) const uint SPLIT_H = 16;
+
// Subgroups per workgroup / state elements per lane. NOTE(review): integer
// division — assumes SUBGROUP_SIZE divides D_STATE exactly (true for the
// 128/256 variants with 32/64-wide subgroups; a 96-wide subgroup would
// break this). Also `uint32_t` as a GLSL type normally needs
// GL_EXT_shader_explicit_arithmetic_types — presumably enabled in a shared
// header; verify.
+const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
}
}
// Old cross-row scratch (SPLIT_H*D_STATE floats) replaced by a single
// D_STATE-sized buffer, needed only for the manual reduction fallback.
-shared float stateC[SPLIT_H * D_STATE];
+#if !USE_SUBGROUP_ADD
+shared float temp[D_STATE];
+#endif
void main() {
- const uint tid = gl_LocalInvocationID.x;
- const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
- const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
- const uint seq_idx = gl_WorkGroupID.y;
// New indexing: one subgroup <-> one (head, channel) output element;
// subgroup_idx is the global channel index across the whole dispatch.
+ const uint subgroup = gl_SubgroupID;
+ const uint lane = gl_SubgroupInvocationID;
+ const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
+ const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
+
+ const uint head_idx = subgroup_idx / d_head;
+ const uint head_off = (subgroup_idx % d_head) * 4;
+ const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
- const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
+ const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
- const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
+ const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
// Per-lane registers: each lane holds c_factor of the D_STATE recurrent
// state values for this channel (strided by SUBGROUP_SIZE).
- float state[SPLIT_H];
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
- state[j] = s0[s0_base_idx + j * D_STATE + tid];
- }
+ float state[c_factor];
- for (uint i = 0; i < n_tok; i++) {
- const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
+ state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
+ }
- const float dA = exp(dt_soft_plus * A[A_base_idx]);
// A is loop-invariant per channel — hoisted out of the token loop.
+ float a = A[A_base_idx];
- const float B_val = B[B_base_idx + i * stride_B + tid];
- const float C_val = C[C_base_idx + i * stride_C + tid];
// Sequential scan over tokens (the recurrence forces this order).
+ for (uint i = 0; i < n_tok; i++) {
+ float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
- const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
+ float state_sum = 0.0f;
+ const float dA = exp(dt_soft_plus * a);
+ const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
// state update + partial dot-product with C, accumulated per lane.
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
+ float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
+ float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
-
- stateC[j * D_STATE + tid] = state[j] * C_val;
+ state_sum += state[j] * C_val;
}
// Reduce state_sum across the subgroup: hardware subgroupAdd when
// available, otherwise a shared-memory tree reduction per subgroup.
+#if USE_SUBGROUP_ADD
+ state_sum = subgroupAdd(state_sum);
+#else
+ temp[tid] = state_sum;
barrier();
- [[unroll]]
- for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
- [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
- const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
- if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
- stateC[k] += stateC[k + w];
- }
// Tree reduction within this subgroup's SUBGROUP_SIZE-wide slice of
// temp[]; barrier() is workgroup-wide and executed uniformly (it sits
// outside the lane-divergent if), which is required for correctness.
+ [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
+ if (lane < s) {
+ temp[tid] += temp[tid + s];
}
barrier();
}
-
- [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
- const uint idx = (tid % SUBGROUP_SIZE) +
- D_STATE * (tid / SUBGROUP_SIZE) +
- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
- const uint max_idx = SUBGROUP_SIZE - 1 +
- D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
-
- if (idx < SPLIT_H * D_STATE ||
- max_idx < SPLIT_H * D_STATE) {
- float sc;
-#if USE_SUBGROUP_ADD
- sc = stateC[idx];
- sc = subgroupAdd(sc);
-#else
- [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
- if (idx + offset < SPLIT_H * D_STATE) {
- stateC[idx] += stateC[idx + offset];
- }
- barrier();
- }
- if (tid % SUBGROUP_SIZE == 0) {
- sc = stateC[idx];
- }
+ // get the value from lane 0
+ state_sum = temp[subgroup * SUBGROUP_SIZE];
// This barrier keeps the temp[] read above ordered before the next
// iteration's overwrite of temp[tid].
+ barrier();
#endif
- if (tid % SUBGROUP_SIZE == 0) {
- const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
- d[y_base_idx + i * stride_y + k] = sc;
- }
- }
// One output scalar per channel per token, written by lane 0 only.
+ if (lane == 0) {
+ d[y_base_idx + i * stride_y] = state_sum;
}
-
- barrier();
}
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
- d[s_base_idx + j * D_STATE + tid] = state[j];
+ // write back the state
// NOTE(review): `int j` here vs `uint j` in the identical loops above —
// signed/unsigned comparison against the uint c_factor; harmless for
// these sizes but inconsistent; prefer `uint j` to match.
+ [[unroll]]
+ for (int j = 0; j < c_factor; j++) {
+ d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
}
}