vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
+ // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
+ vk_pipeline pipeline_gated_delta_net[3][2];
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
uint32_t C;
uint32_t H;
};
+struct vk_op_gated_delta_net_push_constants {
+ uint32_t H; // number of heads
+ uint32_t n_tokens;
+ uint32_t n_seqs;
+ uint32_t s_off; // element offset of the recurrent state within dst
+ uint32_t sq1, sq2, sq3; // q/k strides, in floats
+ uint32_t sv1, sv2, sv3; // v strides, in floats
+ uint32_t sb1, sb2, sb3; // g/beta strides, in floats
+ uint32_t neq1, rq3; // q head count and q->v sequence broadcast ratio
+ float scale; // attention scale, 1/sqrt(head size)
+};
+
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+ {
+ const uint32_t gdn_sizes[] = {32, 64, 128};
+ const char * gdn_names[][2] = {
+ {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
+ {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
+ {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
+ };
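+ // spec constants map to {S_V, KDA}; S_V also sets the workgroup size (local_size_x_id = 0)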
+ for (uint32_t si = 0; si < 3; si++) {
+ for (uint32_t kda = 0; kda < 2; kda++) {
+ ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
+ gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
+ "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+ {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+ }
+ }
+ }
+
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
+ case GGML_OP_GATED_DELTA_NET:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ const uint32_t S_v = dst->src[2]->ne[0];
+ // KDA: the gate has one value per channel (ne[0] == S_v) rather than one per head
+ const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
+ uint32_t si;
+ switch (S_v) {
+ case 32: si = 0; break;
+ case 64: si = 1; break;
+ case 128: si = 2; break;
+ default: return nullptr;
+ }
+ return ctx->device->pipeline_gated_delta_net[si][kda];
+ }
+ return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
);
}
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
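+ // src order (matches the shader bindings): 0 = q, 1 = k, 2 = v, 3 = g, 4 = beta, 5 = state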
+ const ggml_tensor * src_q = dst->src[0];
+ const ggml_tensor * src_v = dst->src[2];
+ const ggml_tensor * src_beta = dst->src[4];
+
+ GGML_ASSERT(dst->buffer != nullptr);
+
+ const uint32_t S_v = (uint32_t)src_v->ne[0]; // head size
+ const uint32_t H = (uint32_t)src_v->ne[1]; // number of heads
+ const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+ const uint32_t n_seqs = (uint32_t)src_v->ne[3];
+
+ // offset (in floats) of the recurrent state within dst, past the attention output
+ const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+ GGML_ASSERT(pipeline != nullptr);
+
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+ vk_subbuffer src_buf[6] = {};
+ for (int i = 0; i < 6; i++) {
+ src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+ }
+
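+ // strides in float elements; g and beta share the sb strides (in KDA mode g adds an inner S_v dim)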
+ const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+ const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+ const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+ const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+ const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+ const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+ const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
+ const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
+ const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
+
+ const uint32_t neq1 = (uint32_t)src_q->ne[1]; // q may have fewer heads than v
+ const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]); // sequence broadcast ratio
+
+ const float scale = 1.0f / sqrtf((float)S_v);
+ const vk_op_gated_delta_net_push_constants pc = {
+ H, n_tokens, n_seqs, s_off,
+ sq1, sq2, sq3,
+ sv1, sv2, sv3,
+ sb1, sb2, sb3,
+ neq1, rq3,
+ scale
+ };
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+ {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+ pc, { H, n_seqs, 1u });
+}
+
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
break;
+ case GGML_OP_GATED_DELTA_NET:
+ ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+ break;
+
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node);
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true; // all inputs are contiguous, see ggml.c
+ case GGML_OP_GATED_DELTA_NET:
+ {
+ for (int i = 0; i < 6; i++) {
+ if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
+ return false;
+ }
+ }
+ // only head sizes with dedicated pipelines are supported
+ const uint32_t S_v = op->src[2]->ne[0];
+ if (S_v != 32 && S_v != 64 && S_v != 128) {
+ return false;
+ }
+ return op->type == GGML_TYPE_F32;
+ }
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
+ } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+ tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(constant_id = 0) const uint S_V = 128;
+layout(constant_id = 1) const uint KDA = 0;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+ uint H;
+ uint n_tokens;
+ uint n_seqs;
+ uint s_off;
+ uint sq1, sq2, sq3;
+ uint sv1, sv2, sv3;
+ uint sb1, sb2, sb3;
+ uint neq1, rq3;
+ float scale;
+};
+
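+// FLOAT_TYPE is expected to be provided by the shader build, as with the other
+// ggml Vulkan shaders; only float (f32) variants of this pipeline are instantiated.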
+layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
+layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
+layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
+layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };
+layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
+layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
+layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
+
+shared FLOAT_TYPE s_k[S_V];
+shared FLOAT_TYPE s_q[S_V];
+shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+
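+// One workgroup per (head, sequence) pair; each invocation owns column `col` of the
+// S_V x S_V recurrent state and walks the tokens of the sequence in order.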
+void main() {
+ const uint head_id = gl_WorkGroupID.x;
+ const uint seq_id = gl_WorkGroupID.y;
+ const uint col = gl_LocalInvocationID.x;
+
+ const uint iq1 = head_id % neq1;
+ const uint iq3 = seq_id / rq3;
+
+ const uint state_size = S_V * S_V;
+ const uint state_base = (seq_id * H + head_id) * state_size;
+
+ FLOAT_TYPE state[S_V];
+ [[unroll]] for (uint i = 0; i < S_V; i++) {
+ state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+ }
+
+ uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
+
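+ // Gated delta rule per token, on this invocation's state column (g = exp of the raw gate):
+ //   non-KDA (scalar gate g):      kv = S^T k;           delta = (v - g*kv) * beta; S = g*S + k*delta^T
+ //   KDA (per-channel gate g[i]):  kv = (diag(g)*S)^T k; delta = (v - kv) * beta;   S = diag(g)*S + k*delta^T
+ //   output = scale * S^T q   (computed on the updated state)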
+ for (uint t = 0; t < n_tokens; t++) {
+ const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+ const uint k_off = q_off; // k shares q's layout and strides
+ const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+ s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
+ s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
+
+ const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
+
+ if (KDA != 0) {
+ const uint g_base = gb_off * S_V;
+ s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
+ }
+
+ barrier(); // make s_q/s_k/s_g visible to the whole workgroup
+
+ const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
+ const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
+
+ if (KDA == 0) {
+ const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
+
+ FLOAT_TYPE kv_col = 0.0;
+ [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+ kv_col += dot(
+ vec4(state[i], state[i+1], state[i+2], state[i+3]),
+ vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
+ );
+ }
+
+ FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
+
+ FLOAT_TYPE attn_col = 0.0;
+ [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+ vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+ vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+ sv = g_val * sv + kv * delta_col;
+ state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+ attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+ }
+
+ data_dst[attn_off + col] = attn_col * scale;
+ } else {
+ FLOAT_TYPE kv_col = 0.0;
+ [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+ vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+ vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+ vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+ kv_col += dot(gv * sv, kv);
+ }
+
+ FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+
+ FLOAT_TYPE attn_col = 0.0;
+ [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+ vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+ vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+ vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+ sv = gv * sv + kv * delta_col;
+ state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+ attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+ }
+
+ data_dst[attn_off + col] = attn_col * scale;
+ }
+
+ attn_off += S_V * H; // advance one token in dst
+ barrier(); // all reads of s_q/s_k/s_g must finish before the next token overwrites them
+ }
+
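+ // store the final recurrent state into dst at s_off, after the attention outputs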
+ [[unroll]] for (uint i = 0; i < S_V; i++) {
+ data_dst[s_off + state_base + i * S_V + col] = state[i];
+ }
+}
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
+ // GATED_DELTA_NET: realistic model configurations
+ // TG: n_seq_tokens=1 (autoregressive)
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); // Qwen3.5-like: 32 heads, d=128
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 1)); // smaller model
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1, 1, false, true)); // KDA
+ // PP: larger n_seq_tokens (prompt processing)
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1)); // PP-64
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 256, 1)); // PP-256
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 512, 1)); // PP-512
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1024, 1)); // PP-1024
+ // Small model configs (fewer heads = less GPU occupancy for autoregressive)
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1)); // 4h PP-64
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1)); // 4h PP-256
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 512, 1)); // 4h PP-512
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
+
return test_cases;
}