]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : support Falcon-H1 state size for SSM_SCAN (llama/14602)
authorcompilade <redacted>
Thu, 10 Jul 2025 03:54:38 +0000 (23:54 -0400)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/ssm-scan.cu
tests/test-backend-ops.cpp

index da1e8f8f4e44302fabb663e201d6a5150d173912..72406f0af36225dd4f6a4c15d01bd888d88fa5a8 100644 (file)
@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SSM_SCAN: {
             if (op->src[3]->ne[0] == 1) {
                 // Mamba2
-                // (kernel only supports d_state == 128 && d_head % 16 == 0)
-                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+                // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+                return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
             } else {
                 // Mamba
                 // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
index dc3b1a9a8cbf08bb44263ca63a42875f99717100..c9184398b422c7f3c1e1938692db24d40dbacb1f 100644 (file)
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
-    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
+            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else if (d_state == 256) { // Falcon-H1
+            const int threads = 256;
+            // NOTE: can be any power of two between 8 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
         } else {
-            GGML_ABORT("doesn't support d_state!=128.");
+            GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
+        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
index 1d837b4322cfa7a5d21cd468bb1c0a5751d6f49d..4eeeb6e43a40027dbfad22b52270177f89dc2dae 100644 (file)
@@ -5069,6 +5069,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64,  8, 2, 32, 4)); // Falcon-H1
 
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));