const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+ // state is stored transposed: s_out[j*S_v + i] = S[i][j]
+ // so row j of s_out = column j of S (contiguous access)
+
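+ // per-token update in terms of the logical S (computed on the transposed M):
+ //   S[i][:] *= exp(g[i])          (per-row gate if kda, scalar gate otherwise)
+ //   delta    = (v - S^T k) * beta
+ //   S       += k * delta^T        (S[i][j] += k[i] * delta[j])
+ //   out      = scale * S^T q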
if (kda) {
+ // precompute exp(g) into the delta scratch buffer (the buffer is overwritten with the real deltas further below)
for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
+ delta[i] = expf(g_d[i]);
+ }
+ // S[i][:] *= exp(g[i]) => element-wise multiply each row j of M by exp(g): M[j][i] *= exp(g[i])
+ for (int64_t j = 0; j < S_v; ++j) {
+ ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
}
} else {
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
}
- // delta[j] = sum_i S[j][i] * k[i]
- memset(delta, 0, S_v * sizeof(float));
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
- }
+ // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
for (int64_t j = 0; j < S_v; ++j) {
- delta[j] = (v_d[j] - delta[j]) * beta_val;
+ float sum = 0.0f;
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
+ delta[j] = (v_d[j] - sum) * beta_val;
}
- // outer product: S[j][i] += k[i] * delta[j]
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
+ // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
+ for (int64_t j = 0; j < S_v; ++j) {
+ ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
}
- // attn_out[j] = sum_i S[j][i] * q[i]
- memset(attn_data, 0, S_v * sizeof(float));
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
+ // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
+ for (int64_t j = 0; j < S_v; ++j) {
+ float sum = 0.0f;
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
+ attn_data[j] = sum * scale;
}
- ggml_vec_scale_f32(S_v, attn_data, scale);
attn_data += S_v * H; // advance to next token
}
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
+ // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
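+ // (assuming each warp handles a single col, consecutive lanes now touch consecutive floats, i.e. coalesced accesses)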
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
- s_shard[r] = curr_state[i * S_v + col];
+ s_shard[r] = curr_state[col * S_v + i];
}
for (int t = 0; t < n_tokens; t++) {
attn_data += S_v * H;
}
- // Write state back to global memory
+ // Write state back to global memory (transposed layout)
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
- state[i * S_v + col] = s_shard[r];
+ state[col * S_v + i] = s_shard[r];
}
}
-static size_t calculate_smem(const int sv, int cc)
-{
- size_t smem = 0;
- if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
- smem = sv * sv * sizeof(float);
- }
- return smem;
-}
-
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64: {
- constexpr int sv = 64;
- size_t smem = calculate_smem(sv, cc);
- gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+ gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
- constexpr int sv = 128;
- size_t smem = calculate_smem(sv, cc);
- gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+ gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
const float scale = 1.0f / sqrt((float)S_v);
- device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+ // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
+ device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
float ls[NSG];
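+ // each thread reads its NSG consecutive elements of row i20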
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
- ls[j] = s_ptr[is*S_v];
+ ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
g_ptr += args.ne21*G;
}
- device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
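+ // the state output uses the same transposed layout, so the writes below are also contiguous within row i20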
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
- dst_state[is*S_v] = ls[j];
+ dst_state[is] = ls[j];
}
#undef S_v
FLOAT_TYPE state[S_V];
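+ // state is stored transposed: element i of column col of S lives at col * S_V + i, so the reads below are contiguous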
[[unroll]] for (uint i = 0; i < S_V; i++) {
- state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+ state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
}
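+ // write the updated state back in the same transposed layout (contiguous within col * S_V)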
[[unroll]] for (uint i = 0; i < S_V; i++) {
- data_dst[s_off + state_base + i * S_V + col] = state[i];
+ data_dst[s_off + state_base + col * S_V + i] = state[i];
}
}
ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
cb(kg_t, "key_gdiff_t", il);
- ggml_tensor * s_t = ggml_transpose(ctx0, s);
- s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
- cb(s_t, "dnet_add_ch_state", il);
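+ // the state already arrives transposed, so a plain reshape replaces the old transpose + cont copy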
+ s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
+ cb(s, "dnet_add_ch_state", il);
// [CS, S_v, n_chunks, H_v * n_seqs]
ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
// [CS, S_v, 1, H_v * n_seqs]
- ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
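+ // s now has the layout the old s_t had, so it feeds the mul_mat directly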
+ ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
cb(v_t_p, "v_prime", il);
// [CS, S_v, 1, H_v * n_seqs]
cb(v_attn, "v_attn", il);
// [S_v, CS, 1, H_v * n_seqs]
- ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+ ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
cb(attn_inter, "attn_inter", il);
// [S_v, CS, 1, H_v * n_seqs]
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
- s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
- s_t = ggml_add(ctx0, s_t, kgv);
- cb(s_t, "dnet_add_ch_state", il);
+ s = ggml_mul(ctx0, s, ch_g_last_exp_t);
+ s = ggml_add(ctx0, s, kgv);
+ cb(s, "dnet_add_ch_state", il);
}
- s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
-
// truncate padded tokens
ggml_tensor * o = ggml_view_4d(ctx0, v,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(v->type, S_v * CS * n_chunks),
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
- s = ggml_transpose(ctx0, s_t);
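+ // keep the transposed storage; only restore the [S_v, S_v, H_v, n_seqs] shape for the caller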
+ s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
cb(s, "output_state", il);
return {o, s};
g = ggml_exp(ctx0, g);
s = ggml_mul(ctx0, s, g);
- ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
-
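+ // the incoming state already has the old s_t orientation, so it is used directly below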
// [1, S_v, H_v, n_seqs]
ggml_tensor * sk;
- sk = ggml_mul (ctx0, s_t, k);
+ sk = ggml_mul (ctx0, s, k);
sk = ggml_sum_rows(ctx0, sk);
// [S_v, 1, H_v, n_seqs]
k = ggml_repeat(ctx0, k, s);
kd = ggml_mul (ctx0, k, d_t);
- s_t = ggml_add(ctx0, s_t, kd);
+ s = ggml_add(ctx0, s, kd);
- cb(s_t, "dnet_add_ar_state", il);
+ cb(s, "dnet_add_ar_state", il);
- ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
+ ggml_tensor * s_q = ggml_mul (ctx0, s, q);
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
- s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
return {o, s};
}