]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use shared mem for ssm_conv (#20128)
authorAman Gupta <redacted>
Fri, 6 Mar 2026 15:09:59 +0000 (23:09 +0800)
committerGitHub <redacted>
Fri, 6 Mar 2026 15:09:59 +0000 (23:09 +0800)
* CUDA: use shared mem for ssm_conv

* fuse silu + ssm_conv

* fuse unary + mul

* enable for fp16

* formatting

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/ssm-conv.cu
ggml/src/ggml-cuda/ssm-conv.cuh
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
tests/test-backend-ops.cpp

index b2dcaf42fc3a0356fd692bcf942550694a37956a..35015bc7f3b1ad81d1309650a87539e4f23e73f4 100644 (file)
@@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph *                cgraph,
         return true;
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
+     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
+        const ggml_tensor * silu     = cgraph->nodes[node_idx+1];
+
+        if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        return true;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
+     && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx];
+        const ggml_tensor * mul   = cgraph->nodes[node_idx+1];
+
+        if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+            return false;
+        }
+
+        if (unary->type != mul->type) {
+            return false;
+        }
+
+        const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
+        if (other->type != unary->type) {
+            return false;
+        }
+        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
      && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale  = cgraph->nodes[node_idx];
@@ -3836,6 +3876,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                         continue;
                     }
 
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+                        ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
+                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
+                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
+                        ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                         i += 2;
                         ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
index 6d5ea704c65a25d7087b5ef968112a7638e163fb..85e82b5a422961823d290155b8120cf431641233 100644 (file)
@@ -1,6 +1,7 @@
 #include "ssm-conv.cuh"
+#include "unary.cuh"
 
-template <size_t split_d_inner, size_t d_conv>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
@@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
-template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
@@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     const int stride_w = src1_nb1 / sizeof(float);
     const int stride_y = dst_nb1 / sizeof(float);
 
-    float x[d_conv] = { 0.0f };
-    float w[d_conv] = { 0.0f };
+    const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
+    const int     n_cols    = d_conv - 1 + split_n_t;
+
+    extern __shared__ float smem[];
 
+    constexpr int load_cols   = d_conv - 1 + split_n_t;
+    constexpr int total_elems = split_d_inner * load_cols;
+    int row = tid / load_cols;
+    int col = tid % load_cols;
 #pragma unroll
-    for (size_t j = 0; j < d_conv; j++) {
-        w[j] = w_block[tid * stride_w + j];
+    for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+        if (row < (int)split_d_inner) {
+            smem[row * n_cols + col] = x_block[row * stride_x + col];
+        }
+
+        col += split_d_inner;
+        row += col / load_cols;
+        col  = col % load_cols;
     }
+    __syncthreads();
 
+    // Load weights into registers (done once, small)
+    float w[d_conv] = { 0.0f };
 #pragma unroll
-    for (int64_t i = 0; i < split_n_t; i++) {
-        if (bidz * split_n_t + i < n_t) {
-            float sumf = 0.0f;
-
-            if (i == 0) {
-                for (size_t j = 0; j < d_conv; j++) {
-                    x[j] = x_block[tid * stride_x + j];
-                }
-            } else {
-                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
-            }
+    for (size_t j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
 
+    // Compute from shared memory
+    for (int64_t i = 0; i < local_n_t; i++) {
+        float sumf = 0.0f;
 #pragma unroll
-            for (size_t j = 0; j < d_conv; j++) {
-                sumf += x[(i + j) % d_conv] * w[j];
-            }
-            y_block[i * stride_y + tid] = sumf;
+        for (size_t j = 0; j < d_conv; j++) {
+            sumf += smem[tid * n_cols + i + j] * w[j];
         }
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
+template <bool apply_silu>
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
@@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         constexpr int kNC = decltype(NC)::value;
         if (n_t <= 32) {
             const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-            ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                        dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             const int64_t split_n_t = 32;
             dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
+            const size_t  smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
+            ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         }
     };
@@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // conv_x
     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+    const bool fuse_silu = silu_dst != nullptr;
+
+    // When fusing, write to silu_dst (the node downstream references).
+    const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
 
     const int64_t nc  = src1->ne[0];                // d_conv
     const int64_t nr  = src0->ne[1];                // d_inner
-    const int64_t n_t = dst->ne[1];                 // tokens per sequence
-    const int64_t n_s = dst->ne[2];                 // number of sequences in the batch
+    const int64_t n_t = out->ne[1];                 // tokens per sequence
+    const int64_t n_s = out->ne[2];                 // number of sequences in the batch
 
-    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(out->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
-    float *       dst_d  = (float *) dst->data;
+    float *       dst_d  = (float *) out->data;
     cudaStream_t  stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, nr, n_t, n_s, stream);
+    GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_silu) {
+        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    } else {
+        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    }
 }
index 8e6c1f00bfa03daf521694b24143f5d47c4019ae..f96a1cd2484bad1a22a450d77e6542f59c4a5725 100644 (file)
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
index d4866067a4f166ff2019b6cb6ddd44255b66ac93..4ad30fa1f353a6a070b3a955ee0d729991fa713a 100644 (file)
@@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
     }
 }
+
+/* fused unary + mul */
+
+template <float (*op)(float)>
+static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    // unary_node: UNARY op applied to unary_node->src[0]
+    // mul_node:   MUL(a, b) where one of a/b is unary_node
+    // Output goes to mul_node->data
+
+    const ggml_tensor * unary_src = unary_node->src[0];  // input to the unary op
+    const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(unary_src));
+    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
+    GGML_ASSERT(ggml_is_contiguous_1(other_src));
+    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
+    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
+
+    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
+    GGML_ASSERT(unary_src->type == other_src->type);
+    GGML_ASSERT(unary_src->type == mul_node->type);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t k  = ggml_nelements(mul_node);
+    const int64_t nc = unary_src->ne[0];
+    const int64_t unary_stride = unary_src->nb[1];
+    const int64_t other_stride = other_src->nb[1];
+
+    if (unary_src->type == GGML_TYPE_F16) {
+        unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,
+                             (half *) mul_node->data, k, nc,
+                             unary_stride / sizeof(half), other_stride / sizeof(half), stream);
+    } else {
+        unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,
+                             (float *) mul_node->data, k, nc,
+                             unary_stride / sizeof(float), other_stride / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    switch (ggml_get_unary_op(unary_node)) {
+        case GGML_UNARY_OP_SILU:
+            ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);
+            break;
+        default:
+            GGML_ABORT("Unsupported unary op for fused unary+mul");
+    }
+}
index 609046e56941eab99495196a29e81de148912ccf..f1dd2183a6c7ffb0e04f8f30129c3c4068b6d18d 100644 (file)
@@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
+
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
 }
index 7c6938d447bf6fd2b9aaf0df957dcc73b3e59619..c4b9540f4f40da49e135eadc876ddc8ed6b2c469 100644 (file)
@@ -7663,6 +7663,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
+            // long token (n_t > 32, exercises the long_token kernel path)
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
         }
     }