bool lower,
bool uni);
+ // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
+ // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
const float * state_in_base = (const float *)src_state->data;
- const int64_t rq1 = nev1 / neq1;
- const int64_t rk1 = nev1 / nek1;
+ //const int64_t rq1 = nev1 / neq1;
+ //const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence
- const int64_t iq1 = iv1 / rq1;
- const int64_t ik1 = iv1 / rk1;
+ const int64_t iq1 = iv1 % neq1;
+ const int64_t ik1 = iv1 % nek1;
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
- const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+ const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
attn_data += S_v * H; // advance to next token
}
-
}
}
#include "gated_delta_net.cuh"
-#include "ggml-cuda/common.cuh"
template <int S_v, bool KDA>
-__global__ void __launch_bounds__(S_v, 1)
-gated_delta_net_cuda(const float * q,
- const float * k,
- const float * v,
- const float * g,
- const float * beta,
- const float * curr_state,
- float * dst,
- const int64_t H,
- const int64_t n_tokens,
- const int64_t n_seqs,
- const int64_t sq1,
- const int64_t sq2,
- const int64_t sq3,
- const int64_t sv1,
- const int64_t sv2,
- const int64_t sv3,
- const int64_t sb1,
- const int64_t sb2,
- const int64_t sb3,
- const int64_t rq1,
- const int64_t rq3,
- const float scale) {
- const int64_t h_idx = blockIdx.x;
- const int64_t sequence = blockIdx.y;
- const int col = threadIdx.x; // each thread owns one column
-
- const int64_t iq1 = h_idx / rq1;
- const int64_t iq3 = sequence / rq3;
+__global__ void gated_delta_net_cuda(const float * q,
+ const float * k,
+ const float * v,
+ const float * g,
+ const float * beta,
+ const float * curr_state,
+ float * dst,
+ int64_t H,
+ int64_t n_tokens,
+ int64_t n_seqs,
+ int64_t sq1,
+ int64_t sq2,
+ int64_t sq3,
+ int64_t sv1,
+ int64_t sv2,
+ int64_t sv3,
+ int64_t sb1,
+ int64_t sb2,
+ int64_t sb3,
+ const uint3 neqk1_magic,
+ const uint3 rq3_magic,
+ float scale) {
+ const uint32_t h_idx = blockIdx.x;
+ const uint32_t sequence = blockIdx.y;
+ // each warp owns one column, using warp-level primitives to reduce across rows
+ const int lane = threadIdx.x;
+ const int col = blockIdx.z * blockDim.y + threadIdx.y;
+
+ const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+ const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
- // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
- // TODO: check optimal path for RDNA1 and RDNA2 devices.
-#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
- extern __shared__ float s_shared[];
- float * s = s_shared + col * S_v;
-#else
- float s[S_v];
-#endif
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
+ static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+ constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
+ float s_shard[rows_per_lane];
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- s[i] = curr_state[i * S_v + col];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = curr_state[i * S_v + col];
}
for (int t = 0; t < n_tokens; t++) {
const float g_val = expf(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
- float kv_col = 0.0f;
+ float kv_shard = 0.0f;
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- kv_col += s[i] * k_t[i];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ kv_shard += s_shard[r] * k_t[i];
}
+ float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
- float attn_col = 0.0f;
+ float attn_partial = 0.0f;
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- s[i] = g_val * s[i] + k_t[i] * delta_col;
- attn_col += s[i] * q_t[i];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
+ attn_partial += s_shard[r] * q_t[i];
}
- attn_data[col] = attn_col * scale;
+ float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+ if (lane == 0) {
+ attn_data[col] = attn_col * scale;
+ }
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
- float kv_col = 0.0f;
+ float kv_shard = 0.0f;
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- kv_col += expf(g_t[i]) * s[i] * k_t[i];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
+ float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
- float attn_col = 0.0f;
+ float attn_partial = 0.0f;
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
- attn_col += s[i] * q_t[i];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+ attn_partial += s_shard[r] * q_t[i];
}
- attn_data[col] = attn_col * scale;
+ float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+ if (lane == 0) {
+ attn_data[col] = attn_col * scale;
+ }
}
attn_data += S_v * H;
// Write state back to global memory
#pragma unroll
- for (int i = 0; i < S_v; i++) {
- state[i * S_v + col] = s[i];
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ state[i * S_v + col] = s_shard[r];
}
}
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
- int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
- int64_t sq1, int64_t sq2, int64_t sq3,
- int64_t sv1, int64_t sv2, int64_t sv3,
- int64_t sb1, int64_t sb2, int64_t sb3,
- int64_t rq1, int64_t rq3,
+ int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
+ int64_t sq1, int64_t sq2, int64_t sq3,
+ int64_t sv1, int64_t sv2, int64_t sv3,
+ int64_t sb1, int64_t sb2, int64_t sb3,
+ int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
+    // TODO: add chunked kernel for even faster pre-fill
+ const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+ const int num_warps = 4;
+ dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+ dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
- dim3 grid_dims(H, n_seqs, 1);
- dim3 block_dims(S_v, 1, 1);
+ const uint3 neqk1_magic = init_fastdiv_values(neqk1);
+ const uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
switch (S_v) {
- case 32: {
- constexpr int sv = 32;
- size_t smem = calculate_smem(sv, cc);
- gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+ case 16:
+ gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, rq1, rq3, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ break;
+ case 32:
+ gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+ n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
- }
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, rq1, rq3, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, rq1, rq3, scale);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
default:
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
- GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+ GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+ GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
- GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
- GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+ GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
const bool kda = (src_g->ne[0] == S_v);
- const int64_t rq1 = nev1 / neq1;
+ GGML_ASSERT(neq1 == nek1);
+ const int64_t neqk1 = neq1;
+
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, rq1, rq3, scale, stream);
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, rq1, rq3, scale, stream);
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
return res;
}
+// Build (or fetch from the cache) the Metal pipeline for GGML_OP_GATED_DELTA_NET.
+// The kernel is specialized via function constants on S_v (head dim of v) and
+// G (gate width: 1 = scalar gate, S_v = KDA vector gate), and templated on the
+// number of simdgroups per threadgroup, nsg = S_v/32.
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    // v is src[2], dimensions: S_v = ne[0], H = ne[1]
+    const int ne20 = op->src[2]->ne[0]; // S_v
+    const int ne21 = op->src[2]->ne[1]; // H
+    const int ne30 = op->src[3]->ne[0]; // G
+
+    // one simdgroup (32 lanes) covers 32 state rows per column
+    // NOTE(review): only nsg = 1, 2, 4 kernel variants are instantiated, i.e. S_v in {32, 64, 128} -
+    //               the `ne20 % 32 == 0` checks alone would also admit e.g. S_v = 96; confirm against supports_op
+    const int nsg = op->src[2]->ne[0]/32;
+
+    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->ne[0] == ne20 * ne21);
+    GGML_ASSERT(ne20 % 32 == 0);
+
+    // cache key encodes the function-constant values, so different shapes get distinct pipelines
+    snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
+    snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        // S_v and G are baked in as function constants (see FC_gated_delta_net_* in the shader)
+        ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
+        ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg = nsg;
+
+    return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
+ case GGML_OP_GATED_DELTA_NET:
+ return op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
+#define FC_GATED_DELTA_NET 1600
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
+// kernel arguments for kernel_gated_delta_net
+typedef struct {
+    int32_t  ne00; // q (src[0]) shape
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00; // q strides (bytes)
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10; // k (src[1]) shape
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10; // k strides (bytes)
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne20; // v (src[2]) shape: S_v, H, n_tokens, n_seqs
+    int32_t  ne21;
+    int32_t  ne22;
+    int32_t  ne23;
+    uint64_t nb20; // v strides (bytes)
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ns02; // per-token stride of q, in floats
+    int32_t  ns12; // per-token stride of k, in floats
+    int32_t  ns22; // per-token stride of v, in floats
+    int32_t  ne0;  // dst shape
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;  // dst strides (bytes)
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_gated_delta_net;
+
typedef struct {
int32_t ne00;
int32_t ne01;
{
n_fuse = ggml_metal_op_rwkv(ctx, idx);
} break;
+ case GGML_OP_GATED_DELTA_NET:
+ {
+ n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
+ } break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
return 1;
}
+// Encode GGML_OP_GATED_DELTA_NET: each threadgroup runs the full token recurrence
+// for a group of output columns of one (head, sequence) pair, writing the attention
+// output followed by the updated state into dst.
+int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+
+    // shapes/strides of q (src[0]), k (src[1]), v (src[2]) and dst
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
+
+    int ida = 0;
+
+    // NOTE(review): no strides are passed for gate (src[3]) and beta (src[4]) -
+    //               the kernel derives their offsets from v's shape, i.e. assumes they are contiguous; confirm
+    ggml_metal_kargs_gated_delta_net args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne20 =*/ ne20,
+        /*.ne21 =*/ ne21,
+        /*.ne22 =*/ ne22,
+        /*.ne23 =*/ ne23,
+        /*.nb20 =*/ nb20,
+        /*.nb21 =*/ nb21,
+        /*.nb22 =*/ nb22,
+        /*.nb23 =*/ nb23,
+        /*.ns02 =*/ (int32_t) (nb02/sizeof(float)), // per-token stride of q, in floats
+        /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), // per-token stride of k, in floats
+        /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), // per-token stride of v, in floats
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    // buffer indices must match the kernel's parameter order: args, q, k, v, g, b, s, dst
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), ida++);         // dst
+
+    const int nsg = pipeline.nsg;
+
+    // grid: (S_v/nsg, H, n_seqs) threadgroups of 32 x nsg threads;
+    // each simdgroup owns one of the S_v output columns (see kernel_gated_delta_net_impl)
+    ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
+
+    return 1;
+}
+
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
}
}
+constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
+constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
+
+#if 1
+// Gated delta-net recurrence over n_tokens, one column per simdgroup:
+//   S[:,col] *= gate;  d = (v[col] - k^T S[:,col]) * beta;  S[:,col] += k*d;  y[col] = q^T S[:,col] * scale
+// Launch geometry (see ggml_metal_op_gated_delta_net): tgpig = (S_v/NSG, H, B), ntg = (32, NSG, 1).
+// Each simdgroup (fixed ty, 32 lanes tx) owns output/state column i20 and keeps its S_v
+// column elements in registers (NSG per lane); dot products are reduced with simd_sum.
+template<short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+#define S_v FC_gated_delta_net_ne20
+#define G FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x; // lane within the simdgroup
+    const uint ty = tpitg.y; // simdgroup index within the threadgroup
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty; // state/output column owned by this simdgroup
+
+    // q/k head indices: modulo (interleaved) broadcast over the v heads
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    // NOTE(review): scale is recomputed as 1/sqrt(S_v) rather than read from the op - confirm this matches the op's scale parameter
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    // state is a dense S_v x S_v matrix per (seq, head); this simdgroup reads column i20
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    // lane tx holds state rows tx*NSG .. tx*NSG + NSG-1 of column i20
+    float ls[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        ls[j] = s_ptr[is*S_v];
+    }
+
+    // dst holds the attention output first; the updated state is appended after it
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    // beta and gate offsets are derived from v's shape (assumed contiguous - see encoder note)
+    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        float s_k = 0.0f;
+
+        if (G == 1) {
+            // scalar gate: decay the whole column by exp(g) while accumulating k^T S[:,col]
+            const float g_exp = exp(g_ptr[0]);
+
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= g_exp;
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        } else {
+            // KDA: per-row vector gate
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= exp(g_ptr[is]);
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        }
+
+        s_k = simd_sum(s_k); // reduce the partial dot product across the 32 lanes
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0]; // delta-rule update coefficient
+
+        float y = 0.0f;
+
+        // rank-1 state update fused with the output dot product q^T S[:,col]
+        FOR_UNROLL (short j = 0; j < NSG; j++) {
+            const short is = tx*NSG + j;
+            ls[j] += k_ptr[is]*d;
+
+            y += ls[j]*q_ptr[is];
+        }
+
+        y = simd_sum(y);
+
+        if (tx == 0) {
+            // lane 0 holds the reduced value: column i20 of token t's output
+            dst_attn[t*args.ne21*S_v] = y*scale;
+        }
+
+        // advance to the next token
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+    }
+
+    // write the final state after the attention output
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = ls[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
+
+#else
+// a simplified version of the above
+// no performance improvement, so keep the above version for now
+
+// Identical recurrence to the kernel above, but each lane's NSG state rows are
+// packed into one vector of type T (float, float2 or float4), collapsing the
+// per-row unrolled loops into vector operations.
+template<typename T, short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+#define S_v FC_gated_delta_net_ne20
+#define G FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x; // lane within the simdgroup
+    const uint ty = tpitg.y; // simdgroup index within the threadgroup
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty; // state/output column owned by this simdgroup
+
+    // q/k head indices: modulo (interleaved) broadcast over the v heads
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float lsf[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        lsf[j] = s_ptr[is*S_v];
+    }
+
+    // reinterpret the NSG scalars as one vector of type T
+    thread T * ls = (thread T *) (lsf);
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        device const T * qt_ptr = (device const T *) (q_ptr);
+        device const T * kt_ptr = (device const T *) (k_ptr);
+        device const T * gt_ptr = (device const T *) (g_ptr);
+
+        if (G == 1) {
+            *ls *= exp(g_ptr[0]);
+        } else {
+            // KDA: per-row vector gate
+            *ls *= exp(gt_ptr[tx]);
+        }
+
+        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        *ls += kt_ptr[tx]*d;
+
+        const float y = simd_sum(dot(*ls, qt_ptr[tx]));
+
+        if (tx == 0) {
+            *dst_attn = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+
+        dst_attn += args.ne21*S_v;
+    }
+
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device T * dstt_state = (device T *) (dst_state); // NOTE(review): unused - the store below goes through dst_state
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = lsf[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
+#endif
+
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
}
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
// KDA (vector gate)
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true));
#if 0
    // these tests are disabled to save execution time, but they can be handy for debugging