]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml: aarch64: Implement SVE F32 kernels for Mamba Sequential Scan Algorithm (llama...
authorVineel Abhinav <redacted>
Thu, 29 May 2025 09:18:43 +0000 (14:48 +0530)
committerGeorgi Gerganov <redacted>
Sun, 1 Jun 2025 12:14:44 +0000 (15:14 +0300)
* F32-Mamba-Seq_Scan-SVE

* Fix formatting

* ggml : missing space

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/vec.h

index f4692f3cd5b90d8b2dd651e66e9b7f2be366c99a..d8de7531b0e5fac1982812a07754bcefee6a7c9c 100644 (file)
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
+    #ifdef __ARM_FEATURE_SVE
+        for (int i3 = 0; i3 < n_s; ++i3) {
+            for (int i2 = 0; i2 < n_t; ++i2) {
+                const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+                const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+                const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+                const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+                const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                    float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                    float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+
+                // use the output as the source for the next token-wise iterations
+                if (i2 > 0) { s0 = s; }
+
+                // d_inner
+                for (int i1 = 0; i1 < ir; ++i1) {
+                    float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                    float x_dt = x[i1] * dt_soft_plus;
+                    svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                    svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                    svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                    for (int64_t k = 0; k < nc; k += svcntw()) {
+                        svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
+                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
+                        svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                        svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                        t1 = exp_ps_sve(svptrue_b32(), t1);
+                        svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                        vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                        r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                        GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                    }
+                    y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
                 }
-                y[i1] = sumf;
             }
         }
-    }
+    #else
+        for (int i3 = 0; i3 < n_s; ++i3) {
+            for (int i2 = 0; i2 < n_t; ++i2) {
+                const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+                const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+                const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+                const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+                const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                    float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                    float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+
+                // use the output as the source for the next token-wise iterations
+                if (i2 > 0) { s0 = s; }
+
+                // d_inner
+                for (int i1 = 0; i1 < ir; ++i1) {
+                    // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                    float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                    float x_dt = x[i1] * dt_soft_plus;
+                    float sumf = 0.0f;
+                    // d_state
+                    for (int i0 = 0; i0 < nc; ++i0) {
+                        int i = i0 + i1*nc;
+                        // state = prev_state * dA + dB * x
+                        float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                        // y = rowwise_dotprod(state, C)
+                        sumf += state * C[i0];
+                        s[i] = state;
+                    }
+                    y[i1] = sumf;
+                }
+            }
+        }
+    #endif
 }
 
 void ggml_compute_forward_ssm_scan(
index 163a274352cb7b09a761d33aa29dd3035f4cb561..09dbade2179fb386e30d42c578c391348566d45d 100644 (file)
@@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
 #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif
 
+/* Below function was borrowed from the GitHub repository:
+https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
+#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
+        // Constants
+        const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
+        const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
+        const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
+        const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+        const svfloat32_t one = svdup_n_f32(1.0f);
+        const svfloat32_t inactive1 = svdup_n_f32(0.0f);
+        const svint32_t inactive2 = svdup_n_s32(0);
+
+        // Algorithm starts here
+        svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);  // y = x * log2(e)
+        svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);         // rount to int (float)
+        svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1);         // n
+
+        t1 = svsub_f32_m(pg, t0, t1);   // a = y - floor(y)
+        t1 = svadd_f32_m(pg, t1, one);  // b = a + 1
+
+        svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
+        svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)
+        t4 = svscale_f32_m(pg, t4, t2);                                    // fexpa(v) * 2^(n)
+
+        // and_(t2.d, t1.d, not_mask17.d)
+        svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
+        t5 = svsub_f32_m(pg, t1, t5);                // z
+        t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z
+        t0 = svmla_f32_m(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+        t0 = svmul_f32_m(pg, t0, t4);                // Final result
+
+        return t0;
+    }
+#endif
+
 #if defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine