]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix SSM_SCAN for n_groups > 1 (#15625)
authorcompilade <redacted>
Thu, 28 Aug 2025 14:11:36 +0000 (10:11 -0400)
committerGitHub <redacted>
Thu, 28 Aug 2025 14:11:36 +0000 (10:11 -0400)
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/ssm-scan.cu
ggml/src/ggml-metal/ggml-metal.metal

index 93330b43a9b84c72d4a96f4aab692d54bb9c1036..8c1f7948855ac5c7076f781da9a36fa0a233c207 100644 (file)
@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
-    // allows optimizing the modulo since n_group should be a power of 2
-    GGML_ASSERT((ng & -ng) == ng);
+    GGML_ASSERT(nh % ng == 0);
 
     // heads per thread
     const int dh = (nh + nth - 1)/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                     const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
                     const float dA = expf(dt_soft_plus * A[h]);
+                    const int g = h / (nh / ng); // repeat_interleave
 
                     // dim
                     for (int i1 = 0; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                             // TODO: maybe unroll more?
                             for (int j = 0; j < 1; j++) {
                                 GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
-                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
 
                                 t0 = GGML_F32_VEC_MUL(t0, adA);
                                 t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -9090,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                         for (int i = 0; i < np; i += GGML_F32_STEP) {
                             for (int j = 0; j < GGML_F32_ARR; j++) {
                                 ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
-                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
-                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
 
                                 ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
                                 ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -9113,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                         // d_state
                         for (int i0 = np; i0 < nc; ++i0) {
                             const int i = i0 + ii*nc;
-                            const int ig = i0 + (h & (ng - 1))*nc;
+                            const int ig = i0 + g*nc;
                             // state = prev_state * dA + dB * x
                             const float state = (s0[i] * dA) + (B[ig] * x_dt);
                             // y = rowwise_dotprod(state, C)
@@ -9130,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 for (int h = ih0; h < ih1; ++h) {
                     // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                     const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const int g = h / (nh / ng); // repeat_interleave
 
                     // dim
                     for (int i1 = 0; i1 < nr; ++i1) {
@@ -9144,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                         // TODO: what happens when (d_state % svcntw()) != 0?
                         for (int64_t k = 0; k < nc; k += svcntw()) {
                             svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
-                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
                             svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 
                             svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9165,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                         // d_state
                         for (int i0 = 0; i0 < nc; ++i0) {
                             const int i = i0 + ii*nc;
-                            const int ig = i0 + (h & (ng - 1))*nc;
+                            const int ig = i0 + g*nc;
                             // state = prev_state * dA + dB * x
                             const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                             // y = rowwise_dotprod(state, C)
index dc9a7d58d057c14f6090081e3a872a45a11328c8..6b424381df5a7311d78b450195e5cb475f787299 100644 (file)
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
     const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
     const int seq_idx = blockIdx.y;
 
-    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
 
     const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
     const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
index fa80d6e405978dd47a271890343c2a804574327b..4fa16c4a553d290d6a0f6bc368686eb596f13f96 100644 (file)
@@ -1983,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
     const int64_t i = i0 + i1*nc;
+    const int64_t g = ir / (nh / ng); // repeat_interleave
     float s0 = s0_buff[i];
     float s  = s_buff[i];
 
         device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
         device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
         device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-        device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
-        device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+        device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+        device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
         device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -2098,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
     const int64_t i = i0 + i1*nc;
+    const int64_t g = ir / (nh / ng); // repeat_interleave
     float s0 = s0_buff[i];
     float s  = s_buff[i];
 
     device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
     device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
     device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
     device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {