]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : add softcap fusion (llama/14907)
authorSigbjørn Skjæret <redacted>
Tue, 29 Jul 2025 12:22:03 +0000 (14:22 +0200)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/softcap.cu [new file with mode: 0644]
src/ggml-cuda/softcap.cuh [new file with mode: 0644]
tests/test-backend-ops.cpp

index 1f785796014bd90392445fb27e262fcaa366e2c6..819c2be78ec12fabe5cb6bfadecb0a7fc0ae0e9d 100644 (file)
@@ -33,6 +33,7 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2770,7 +2771,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2798,9 +2804,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+
+        return true;
     }
 
-    return true;
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // Check for bias
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -2821,10 +2850,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 }
 
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-                if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-                    i++;
-                    continue;
+                if (!disable_fusion) {
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                        i += 2;
+                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                        continue;
+                    }
                 }
 #ifndef NDEBUG
                 assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
diff --git a/src/ggml-cuda/softcap.cu b/src/ggml-cuda/softcap.cu
new file mode 100644 (file)
index 0000000..40dfe45
--- /dev/null
@@ -0,0 +1,34 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
+    const ggml_tensor * src0 = src->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    float softcap;
+    memcpy(&scale,   (float *) src->op_params + 0, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
+
+    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
diff --git a/src/ggml-cuda/softcap.cuh b/src/ggml-cuda/softcap.cuh
new file mode 100644 (file)
index 0000000..6d34fb2
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
index 3cc318a9cf96004871fd2c6fe2697e4b7dc6cd98..479b3fad486848ede2c22cfe4fc66bdb74c78981 100644 (file)
@@ -2545,6 +2545,41 @@ struct test_scale : public test_case {
     }
 };
 
+// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+struct test_softcap : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float softcap;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "SOFTCAP";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, softcap);
+    }
+
+    test_softcap(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float softcap = 30.0f)
+        : type(type), ne(ne), softcap(softcap) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SILU_BACK
 struct test_silu_back : public test_case {
     const ggml_type type;
@@ -5421,6 +5456,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {