vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
+ vk_pipeline pipeline_ssm_scan_f32_d128;
+ vk_pipeline pipeline_ssm_scan_f32_d256;
+ vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
uint32_t C;
uint32_t H;
};
+struct vk_op_ssm_scan_push_constants {
+ uint32_t nb02, nb03, nb12, nb13;
+ uint32_t nb21, nb22, nb31;
+ uint32_t nb42, nb43, nb52, nb53;
+ uint32_t s_off;
+ uint32_t n_head, d_head, n_group, n_tok;
+};
+struct vk_op_ssm_conv_push_constants {
+ uint32_t nb01, nb02;
+ uint32_t nb11;
+ uint32_t dst_nb0, dst_nb1, dst_nb2;
+ uint32_t nc, ncs, nr, n_t, n_s;
+};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
+ case GGML_OP_SSM_SCAN:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ const uint32_t d_state = src0->ne[0];
+ if (d_state == 128) {
+ return ctx->device->pipeline_ssm_scan_f32_d128;
+ } else if (d_state == 256) {
+ return ctx->device->pipeline_ssm_scan_f32_d256;
+ }
+ }
+ return nullptr;
+ case GGML_OP_SSM_CONV:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_ssm_conv_f32;
+ }
+ return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
}
}
break;
+ case GGML_OP_SSM_CONV:
+ {
+ const uint32_t nr = src0->ne[1];
+ const uint32_t n_t = dst->ne[1];
+ const uint32_t n_s = dst->ne[2];
+ elements = { nr, n_t, n_s };
+ }
+ break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
);
}
+static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+ const ggml_tensor * src3 = dst->src[3];
+ const ggml_tensor * src4 = dst->src[4];
+ const ggml_tensor * src5 = dst->src[5];
+
+ GGML_ASSERT(dst->buffer != nullptr);
+
+ const uint32_t head_dim = src0->ne[1];
+ const uint32_t n_head = src1->ne[1];
+ const uint32_t n_group = src4->ne[1];
+ const uint32_t n_tok = src1->ne[2];
+ const uint32_t n_seq = src1->ne[3];
+
+ bool is_mamba2 = (src3->nb[1] == sizeof(float));
+ GGML_ASSERT(is_mamba2);
+
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
+ GGML_ASSERT(pipeline != nullptr);
+
+ if (dryrun) {
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+ return;
+ }
+
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
+
+ const vk_op_ssm_scan_push_constants pc = {
+ (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
+ (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
+ (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
+ (uint32_t)src3->nb[1],
+ (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
+ (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
+ (uint32_t)s_off,
+ n_head, head_dim, n_group, n_tok
+ };
+
+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+ ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
+ for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
+ src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+ }
+
+ vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
+ size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
+ bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
+
+ if (ctx->device->uma) {
+ for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
+ ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+ srcs_uma[i] = d_srcs[i] != nullptr;
+ }
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+ dst_uma = d_D != nullptr;
+ }
+
+ if (!dst_uma) {
+ d_D = dst_buf_ctx->dev_buffer;
+ dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+ }
+ for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
+ if (!srcs_uma[i]) {
+ d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+ src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+ }
+ }
+
+ size_t dst_size = ggml_nbytes(dst);
+ size_t src_sizes[GGML_MAX_SRC];
+ for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
+ src_sizes[i] = ggml_nbytes(dst->src[i]);
+ }
+
+ std::array<uint32_t, 3> elements;
+
+ const int splitH = 16;
+ const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
+ const uint32_t num_workgroups_y = n_seq;
+ elements = { num_workgroups_x, num_workgroups_y, 1 };
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+ vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+ vk_subbuffer{ d_D, dst_offset, dst_size }
+ }, pc, elements);
+}
+
+static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
+ (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
+ (uint32_t)src1->nb[1],
+ (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
+ (uint32_t)src1->ne[0],
+ (uint32_t)src0->ne[0],
+ (uint32_t)src0->ne[1],
+ (uint32_t)dst->ne[1],
+ (uint32_t)dst->ne[2],
+ }, dryrun);
+}
+
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
+ case GGML_OP_SSM_SCAN:
+ case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
break;
+ case GGML_OP_SSM_SCAN:
+ ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
+
+ break;
+
+ case GGML_OP_SSM_CONV:
+ ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
+
+ break;
+
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
+ case GGML_OP_SSM_SCAN:
+ case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
+ case GGML_OP_SSM_SCAN:
+ {
+ for (int i = 0; i < 6; i++) {
+ if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
+ return false;
+ }
+ }
+ if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
+ return false;
+ }
+ if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ const uint32_t d_state = op->src[0]->ne[0];
+ const uint32_t head_dim = op->src[0]->ne[1];
+
+ bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
+ if (!is_mamba2) {
+ return false;
+ }
+
+ if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
+ return false;
+ }
+
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ const vk_device& device = ggml_vk_get_device(ctx->device);
+
+ const uint32_t SPLIT_H = 16;
+
+ size_t stateC_size = SPLIT_H * d_state * sizeof(float);
+
+ if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
+ return false;
+ }
+
+ return true;
+ }
+ case GGML_OP_SSM_CONV:
+ return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
struct ggml_context * ggml_ctx = ggml_init(iparams);
- std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
- std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
- std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
- const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
+ std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+ std::array<size_t, GGML_MAX_SRC> src_size = {};
+ std::array<void *, GGML_MAX_SRC> src_buffer = {};
+ const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
struct ggml_tensor * tensor_clone = nullptr;
- for (int i = 0; i < 6; i++) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
+ } else if (tensor->op == GGML_OP_SSM_SCAN) {
+ tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
+ src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
+ } else if (tensor->op == GGML_OP_SSM_CONV) {
+ tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
- for (int i = 0; i < 6; i++) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint D_STATE = 128;
+layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
+layout(constant_id = 2) const uint SPLIT_H = 16;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0) readonly buffer Src0 { float s0[]; };
+layout(binding = 1) readonly buffer Src1 { float x[]; };
+layout(binding = 2) readonly buffer Src2 { float dt[]; };
+layout(binding = 3) readonly buffer Src3 { float A[]; };
+layout(binding = 4) readonly buffer Src4 { float B[]; };
+layout(binding = 5) readonly buffer Src5 { float C[]; };
+layout(binding = 6) readonly buffer Src6 { int ids[]; };
+layout(binding = 7) buffer Dst { float d[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint nb02; uint nb03; uint nb12; uint nb13;
+ uint nb21; uint nb22; uint nb31;
+ uint nb42; uint nb43; uint nb52; uint nb53;
+ uint s_off;
+ uint n_head;
+ uint d_head;
+ uint n_group;
+ uint n_tok;
+};
+
+float softplus(float x) {
+ if (x <= 20.0) {
+ return log(1.0 + exp(x));
+ } else {
+ return x;
+ }
+}
+
+shared float stateC[SPLIT_H * D_STATE];
+
+void main() {
+ const uint tid = gl_LocalInvocationID.x;
+ const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
+ const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
+ const uint seq_idx = gl_WorkGroupID.y;
+
+ const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
+ const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
+ const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
+ const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
+ const uint A_base_idx = (head_idx * nb31) / 4;
+ const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
+ const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
+ const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
+ const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
+
+ const uint stride_x = nb12 / 4;
+ const uint stride_dt = nb21 / 4;
+ const uint stride_B = nb42 / 4;
+ const uint stride_C = nb52 / 4;
+ const uint stride_y = n_head * d_head;
+
+ float state[SPLIT_H];
+ [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
+ state[j] = s0[s0_base_idx + j * D_STATE + tid];
+ }
+
+ for (uint i = 0; i < n_tok; i++) {
+ const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
+
+ const float dA = exp(dt_soft_plus * A[A_base_idx]);
+
+ const float B_val = B[B_base_idx + i * stride_B + tid];
+ const float C_val = C[C_base_idx + i * stride_C + tid];
+
+ [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
+ const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
+
+ state[j] = (state[j] * dA) + (B_val * x_dt);
+
+ stateC[j * D_STATE + tid] = state[j] * C_val;
+ }
+
+ barrier();
+ for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
+ [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
+ const uint k = (tid % (w >> 1)) +
+ (D_STATE * (tid / (w >> 1))) +
+ j * D_STATE * (D_STATE / (w >> 1));
+ if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
+ stateC[k] += stateC[k + (w >> 1)];
+ }
+ }
+ barrier();
+ }
+
+ [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
+ const uint idx = (tid % SUBGROUP_SIZE) +
+ D_STATE * (tid / SUBGROUP_SIZE) +
+ j * D_STATE * (D_STATE / SUBGROUP_SIZE);
+
+ uint lane = tid % SUBGROUP_SIZE;
+
+ [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
+ if (idx + offset < SPLIT_H * D_STATE) {
+ stateC[idx] += stateC[idx + offset];
+ }
+ barrier();
+ }
+
+ if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
+ const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
+ d[y_base_idx + i * stride_y + k] = stateC[idx];
+ }
+ }
+
+ barrier();
+ }
+
+ [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
+ d[s_base_idx + j * D_STATE + tid] = state[j];
+ }
+}