* ggml : transpose fused GDN state access for coalesced memory reads (llama/20436)
The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix
column-wise on row-major storage, causing strided reads (stride S_v =
128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a
39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused
path.
Transpose the state indexing so threads read contiguously:
- Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v)
- CUDA: curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced)
- CPU: restructured loops for row-wise transposed access
Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so
users can control fused GDN independently of auto-detection.
All GATED_DELTA_NET backend-ops tests pass.
Co-Authored-By: Claude Opus 4.6 <redacted>
* ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags
- Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized
dot products in the CPU fused GDN kernel (delta and attention output)
- Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one
path lacks device support, disable both to prevent state layout mismatch
between transposed (fused) and non-transposed (unfused) formats
Co-Authored-By: Claude Opus 4.6 <redacted>
* llama : revert fgdn argument changes
* graph : remove GDN state transposes
* vulkan : adapt
* cuda : remove obsolete smem code
---------
Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
Co-authored-by: Oliver Simons <redacted>
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+ // state is stored transposed: s_out[j*S_v + i] = S[i][j]
+ // so row j of s_out = column j of S (contiguous access)
+
if (kda) {
+ // precompute exp(g) into delta scratch (reused below)
for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
+ delta[i] = expf(g_d[i]);
+ }
+ // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
+ for (int64_t j = 0; j < S_v; ++j) {
+ ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
}
} else {
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
}
- // delta[j] = sum_i S[j][i] * k[i]
- memset(delta, 0, S_v * sizeof(float));
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
- }
+ // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
for (int64_t j = 0; j < S_v; ++j) {
- delta[j] = (v_d[j] - delta[j]) * beta_val;
+ float sum = 0.0f;
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
+ delta[j] = (v_d[j] - sum) * beta_val;
}
- // outer product: S[j][i] += k[i] * delta[j]
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
+ // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
+ for (int64_t j = 0; j < S_v; ++j) {
+ ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
}
- // attn_out[j] = sum_i S[j][i] * q[i]
- memset(attn_data, 0, S_v * sizeof(float));
- for (int64_t i = 0; i < S_v; ++i) {
- ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
+ // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
+ for (int64_t j = 0; j < S_v; ++j) {
+ float sum = 0.0f;
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
+ attn_data[j] = sum * scale;
}
- ggml_vec_scale_f32(S_v, attn_data, scale);
attn_data += S_v * H; // advance to next token
}
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
+ // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
- s_shard[r] = curr_state[i * S_v + col];
+ s_shard[r] = curr_state[col * S_v + i];
}
for (int t = 0; t < n_tokens; t++) {
attn_data += S_v * H;
}
- // Write state back to global memory
+ // Write state back to global memory (transposed layout)
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
- state[i * S_v + col] = s_shard[r];
+ state[col * S_v + i] = s_shard[r];
}
}
-static size_t calculate_smem(const int sv, int cc)
-{
- size_t smem = 0;
- if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
- smem = sv * sv * sizeof(float);
- }
- return smem;
-}
-
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64: {
- constexpr int sv = 64;
- size_t smem = calculate_smem(sv, cc);
- gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+ gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
- constexpr int sv = 128;
- size_t smem = calculate_smem(sv, cc);
- gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+ gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
const float scale = 1.0f / sqrt((float)S_v);
- device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+ // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
+ device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
- ls[j] = s_ptr[is*S_v];
+ ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
g_ptr += args.ne21*G;
}
- device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
- dst_state[is*S_v] = ls[j];
+ dst_state[is] = ls[j];
}
#undef S_v
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
- state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+ state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
- data_dst[s_off + state_base + i * S_V + col] = state[i];
+ data_dst[s_off + state_base + col * S_V + i] = state[i];
}
}