]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : faster ssm scan (#10558)
authora3sh <redacted>
Mon, 31 Mar 2025 16:05:13 +0000 (00:05 +0800)
committerGitHub <redacted>
Mon, 31 Mar 2025 16:05:13 +0000 (18:05 +0200)
* faster ssm_scan

* delete unused commnet

* clang format

* add space

* modify unnecessary calculations

* faster ssm conv implementatioin

* modify file name with dash

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/ssm-conv.cu [new file with mode: 0644]
ggml/src/ggml-cuda/ssm-conv.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/ssm-scan.cu [new file with mode: 0644]
ggml/src/ggml-cuda/ssm-scan.cuh [new file with mode: 0644]

index f2ad692f6617e62c8fa3dc866bd19dcde2a3c125..861927654ec9838fb41c15db30fe377524e55eac 100644 (file)
@@ -31,6 +31,8 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm-conv.cuh"
+#include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
@@ -2296,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -3193,6 +3201,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu
new file mode 100644 (file)
index 0000000..cfe03d6
--- /dev/null
@@ -0,0 +1,151 @@
+#include "ssm-conv.cuh"
+
+template <size_t split_d_inner, size_t d_conv>
+static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
+                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
+                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+    for (int i = 0; i < n_t; i++) {
+        float sumf = 0.0f;
+
+        if (i == 0) {
+            for (int j = 0; j < d_conv; j++) {
+                x[j] = x_block[tid * stride_x + j];
+            }
+        } else {
+            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+        }
+
+#pragma unroll
+        for (int j = 0; j < d_conv; j++) {
+            sumf += x[(i + j) % d_conv] * w[j];
+        }
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
+                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
+                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
+                                               const int nr, const int n_t, const int n_s) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+    const int bidz = blockIdx.z;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+                                             bidz * split_n_t * src0_nb0);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block =
+        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+#pragma unroll
+    for (int i = 0; i < split_n_t; i++) {
+        if (bidz * split_n_t + i < n_t) {
+            float sumf = 0.0f;
+
+            if (i == 0) {
+                for (int j = 0; j < d_conv; j++) {
+                    x[j] = x_block[tid * stride_x + j];
+                }
+            } else {
+                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+            }
+
+#pragma unroll
+            for (int j = 0; j < d_conv; j++) {
+                sumf += x[(i + j) % d_conv] * w[j];
+            }
+            y_block[i * stride_y + tid] = sumf;
+        }
+    }
+}
+
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
+                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
+                              const int n_s, cudaStream_t stream) {
+    const int threads = 128;
+    GGML_ASSERT(nr % threads == 0);
+
+    if (n_t <= 32) {
+        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+        if (nc == 4) {
+            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
+                                                                     n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4  now.");
+        }
+    } else {
+        if (nc == 4) {
+            const int split_n_t = 32;
+            dim3      blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t>
+                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
+                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4 right now.");
+        }
+    }
+}
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // conv_x
+    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+
+    const int nc  = src1->ne[0];                    // d_conv
+    const int ncs = src0->ne[0];                    // d_conv - 1 + n_t
+    const int nr  = src0->ne[1];                    // d_inner
+    const int n_t = dst->ne[1];                     // tokens per sequence
+    const int n_s = dst->ne[2];                     // number of sequences in the batch
+
+    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float *       dst_d  = (float *) dst->data;
+    cudaStream_t  stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
+                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+}
diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh
new file mode 100644 (file)
index 0000000..8e6c1f0
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu
new file mode 100644 (file)
index 0000000..52db17c
--- /dev/null
@@ -0,0 +1,155 @@
+#include "ssm-scan.cuh"
+
+// #include <cuda_runtime.h>
+// static __device__ void global_to_shared(const float *src, float *dst) {
+//   asm volatile("cp.async.");
+// }
+
+template <size_t splitD, size_t N>
+__global__ void __launch_bounds__(splitD, 2)
+    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
+                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
+                 float * __restrict__ dst, const int D, const int L, const int B) {
+    const int bidx = blockIdx.x;  // split along B
+    const int bidy = blockIdx.y;  // split along D
+    const int tid  = threadIdx.x;
+    const int wid  = tid / 32;
+    const int wtid = tid % 32;
+
+    extern __shared__ float smem[];
+    const int               stride_sA  = N + 1;
+    const int               stride_ss0 = N + 1;
+    float *                 smem_A     = smem;
+    float *                 smem_s0    = smem_A + splitD * stride_sA;
+
+    const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * x_block  = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
+    const float * A_block  = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
+    const float * B_block  = (const float *) ((char *) src4 + (bidx * src4_nb2));
+    const float * C_block  = (const float *) ((char *) src5 + (bidx * src5_nb2));
+    float *       y_block  = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    float *       s_block  = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+
+    const int stride_s0 = src0_nb1 / sizeof(float);
+    const int stride_x  = src1_nb1 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_A  = src3_nb1 / sizeof(float);
+    const int stride_B  = src4_nb1 / sizeof(float);
+    const int stride_C  = src5_nb1 / sizeof(float);
+    const int stride_s  = stride_s0;
+    const int stride_y  = stride_x;
+
+    // can N not be 16? for example 32?
+    if (N == 16) {
+#pragma unroll
+        for (int i = 0; i < splitD / 4; i += 2) {
+            float value = A_block[(wid * warpSize + i) * stride_A + wtid];
+            // todo: bank conflict
+            // I am always confused with how to use the swizzling method to solve
+            // bank conflit. Hoping somebody can tell me.
+            smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+        }
+#pragma unroll
+        for (int i = 0; i < splitD / 4; i += 2) {
+            float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
+            smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+        }
+    }
+
+    __syncthreads();
+
+    for (int i = 0; i < L; i++) {
+        float dt_soft_plus = dt_block[i * stride_dt + tid];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        }
+        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float sumf = 0.0f;
+#pragma unroll
+        for (int j = 0; j < N; j++) {
+            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
+                          (B_block[i * stride_B + j] * x_dt);
+            sumf += state * C_block[i * stride_C + j];
+            if (i == L - 1) {
+                s_block[tid * stride_s + j] = state;
+            } else {
+                smem_s0[tid * stride_ss0 + j] = state;
+            }
+        }
+        __syncthreads();
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
+                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
+                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
+                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
+                              float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
+    const int threads = 128;
+    // todo: consider D cannot be divided,does this situation exist?
+    GGML_ASSERT(D % threads == 0);
+    const dim3 blocks(B, (D + threads - 1) / threads, 1);
+    const int  smem_size = (threads * (N + 1) * 2) * sizeof(float);
+    if (N == 16) {
+        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
+            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
+    } else {
+        GGML_ABORT("doesn't support N!=16.");
+    }
+}
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // s
+    const struct ggml_tensor * src1 = dst->src[1];  // x
+    const struct ggml_tensor * src2 = dst->src[2];  // dt
+    const struct ggml_tensor * src3 = dst->src[3];  // A
+    const struct ggml_tensor * src4 = dst->src[4];  // B
+    const struct ggml_tensor * src5 = dst->src[5];  // C
+
+    //   const int64_t d_state = src0->ne[0];
+    //   const int64_t d_inner = src0->ne[1];
+    //   const int64_t l = src1->ne[1];
+    //   const int64_t b = src0->ne[2];
+
+    const int64_t nc  = src0->ne[0];  // d_state
+    const int64_t nr  = src0->ne[1];  // d_inner
+    const int64_t n_t = src1->ne[1];  // number of tokens per sequence
+    const int64_t n_s = src0->ne[2];  // number of sequences in the batch
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * src2_d = (const float *) src2->data;
+    const float * src3_d = (const float *) src3->data;
+    const float * src4_d = (const float *) src4->data;
+    const float * src5_d = (const float *) src5->data;
+    float *       dst_d  = (float *) dst->data;
+    cudaStream_t  stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
+                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
+                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+}
diff --git a/ggml/src/ggml-cuda/ssm-scan.cuh b/ggml/src/ggml-cuda/ssm-scan.cuh
new file mode 100644 (file)
index 0000000..ee078f5
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);