#include "ggml-cuda/common.cuh"
// Gated delta-net recurrence kernel.
//
// Launch layout (see launch_gated_delta_net): grid = (H, n_seqs, 1),
// block = (S_v, 1, 1) — each thread owns one column of the S_v x S_v
// recurrent state. On GCN/CDNA/MUSA devices the caller must pass
// S_v * S_v * sizeof(float) bytes of dynamic shared memory (see
// calculate_smem); on all other devices none is required.
//
// NOTE(review): this chunk contained unresolved diff residue ("-"/"+"
// prefixed lines); the code below is the resolved "+" side. `state_offset`
// and `attn_data` are not declared in the visible lines — presumably they
// are computed from `sequence`/`h_idx` in code omitted from this extract;
// confirm against the full file.
template <int S_v, bool KDA>
__global__ void __launch_bounds__(S_v, 1)
gated_delta_net_cuda(const float * q,
                     const float * k,
                     const float * v,
                     const float * g,
                     const float * beta,
                     const float * curr_state,
                     float * dst,
                     const int64_t H,
                     const int64_t n_tokens,
                     const int64_t n_seqs,
                     const int64_t sq1,
                     const int64_t sq2,
                     const int64_t sq3,
                     const int64_t sv1,
                     const int64_t sv2,
                     const int64_t sv3,
                     const int64_t sb1,
                     const int64_t sb2,
                     const int64_t sb3,
                     const int64_t rq1,
                     const int64_t rq3,
                     const float scale) {
    const int64_t h_idx    = blockIdx.x;  // head index
    const int64_t sequence = blockIdx.y;  // sequence index
    const int     col      = threadIdx.x; // each thread owns one column

    curr_state += state_offset;
    attn_data += (sequence * n_tokens * H + h_idx) * S_v;

    // GCN and CDNA devices spill registers, we use shared mem for them. See
    // https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
    // TODO: check optimal path for RDNA1 and RDNA2 devices.
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
    extern __shared__ float s_shared[];
    float * s = s_shared + col * S_v; // this thread's column of the state tile
#else
    float s[S_v]; // state column kept in registers
#endif
    // Load this thread's state column from the current state.
#pragma unroll
    for (int i = 0; i < S_v; i++) {
        s[i] = curr_state[i * S_v + col];
    }
}
// Dynamic shared-memory size (bytes) that gated_delta_net_cuda needs for a
// state dimension `sv` on a device of compute capability `cc`.
//
// Non-RDNA3/RDNA4 AMD devices (GCN/CDNA) and MooreThreads devices keep the
// per-block sv x sv float state tile in shared memory to avoid register
// spills; every other device keeps it in registers and needs no shared mem.
static size_t calculate_smem(const int sv, const int cc) {
    const bool amd_needs_shared =
        GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc);
    if (amd_needs_shared || GGML_CUDA_CC_IS_MTHREADS(cc)) {
        return (size_t) sv * sv * sizeof(float);
    }
    return 0;
}
+
// Host-side launcher: dispatches gated_delta_net_cuda on the compile-time
// state size S_v (32 / 64 / 128), requesting dynamic shared memory only on
// the devices that need it (see calculate_smem). Grid is one block per
// (head, sequence) pair; block is one thread per state column.
//
// NOTE(review): unresolved diff residue ("-"/"+" prefixes) stripped. The
// parameter list after v_d and the closing braces of this function fall
// outside the visible lines of this chunk — restore them from the full file.
template <bool KDA>
static void launch_gated_delta_net(
    const float * q_d, const float * k_d, const float * v_d,
    dim3 grid_dims(H, n_seqs, 1);
    dim3 block_dims(S_v, 1, 1);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    switch (S_v) {
        case 32: {
            constexpr int sv = 32;
            const size_t smem = calculate_smem(sv, cc);
            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                sb1, sb2, sb3, rq1, rq3, scale);
            break;
        }
        case 64: {
            constexpr int sv = 64;
            const size_t smem = calculate_smem(sv, cc);
            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                sb1, sb2, sb3, rq1, rq3, scale);
            break;
        }
        case 128: {
            constexpr int sv = 128;
            const size_t smem = calculate_smem(sv, cc);
            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                sb1, sb2, sb3, rq1, rq3, scale);
            break;
        }
        default:
            GGML_ABORT("fatal error");
            break;