]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
model : add ASR support for LFM2-Audio-1.5B (conformer) (llama/18106)
authorXuan-Son Nguyen <redacted>
Thu, 18 Dec 2025 23:18:01 +0000 (00:18 +0100)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 15:52:09 +0000 (17:52 +0200)
* ASR with LFM2-Audio-1.5B

* Set rope_theta

* Fix comment

* Remove rope_theta setting

* Address PR feedback

* rename functions to conformer

* remove some redundant ggml_cont

* fix missing tensor

* add prefix "a." for conv tensors

* remove redundant reshape

* clean up

* add test model

---------

Co-authored-by: Tarek Dakhran <redacted>
ggml/src/ggml-cuda/ssm-conv.cu

index 41979733601d27a6bba66ddd3bc1f78a8fde0396..6d5ea704c65a25d7087b5ef968112a7638e163fb 100644 (file)
@@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);
 
-    if (n_t <= 32) {
-        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-        if (nc == 4) {
-            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
-            ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+    auto launch_kernel = [&](auto NC) {
+        constexpr int kNC = decltype(NC)::value;
+        if (n_t <= 32) {
+            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+            ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
-        }
-    } else {
-        if (nc == 4) {
-            const int64_t split_n_t = 32;
-            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
-                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
             const int64_t split_n_t = 32;
             dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+            ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
         }
+    };
+
+    switch (nc) {
+        case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
+        case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+        case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
+        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
     }
 }