]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : remove ggml_flash_attn and ggml_flash_ff (llama/7463)
authorGeorgi Gerganov <redacted>
Thu, 23 May 2024 07:00:44 +0000 (10:00 +0300)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
ggml-ci

include/ggml/ggml.h
src/ggml.c
tests/test-grad0.cpp

index 08835042c0bfdbaa0f1f003760d80612c4a4fcfd..be81e0c52316bed34c719cf5bdf108b3b06947c0 100644 (file)
@@ -481,9 +481,7 @@ extern "C" {
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
 
-        GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_ATTN_EXT,
-        GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
         GGML_OP_SSM_SCAN,
@@ -1761,13 +1759,6 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-    GGML_API struct ggml_tensor * ggml_flash_attn(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * q,
-            struct ggml_tensor  * k,
-            struct ggml_tensor  * v,
-            bool                  masked);
-
 #define GGML_KQ_MASK_PAD 32
 
     // q:    [n_embd, n_batch,     n_head,    1]
@@ -1788,6 +1779,7 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
@@ -1796,14 +1788,6 @@ extern "C" {
            struct ggml_tensor  * d,
            bool                  masked);
 
-    GGML_API struct ggml_tensor * ggml_flash_ff(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b0,
-            struct ggml_tensor  * b1,
-            struct ggml_tensor  * c0,
-            struct ggml_tensor  * c1);
-
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,
index 673c47748e24691fa7a64fd44591ac89655d13c4..9e72b7a765dbae38e6765f31051404ab103d8958 100644 (file)
@@ -2670,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",
 
-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2698,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2760,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",
 
-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2788,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6948,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
@@ -7039,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
@@ -7080,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -15709,400 +15643,6 @@ static void ggml_compute_forward_argsort(
     }
 }
 
-// ggml_compute_forward_flash_attn
-
-static void ggml_compute_forward_flash_attn_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
-                    (float *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
-                        ((char *)             v->data + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16(
@@ -16336,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }
 
-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];  // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data  + (         ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back
 
 static void ggml_compute_forward_flash_attn_back_f32(
@@ -18065,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -19086,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -19140,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 GGML_ASSERT(false); // not supported
@@ -19830,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 n_tasks = n_threads;
@@ -20235,40 +19595,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
                     cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
                     const int64_t    D = node->src[0]->ne[0];
index 8ff76c8910c49917518aa0213882a737bba9a30b..21ca43be3a963bdc5dba3fdba4074d5597f3d96a 100644 (file)
@@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
         }
 
         // flash_attn f32
-        {
-            srand(seed);
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int max_nrep = (ndims >= 3) ? 2 : 1;
-                    for (int nrep = 1; nrep < max_nrep; ++nrep) {
-                        int64_t neq[4] = { D, N, B*nrep, ne[3] };
-                        int64_t nek[4] = { D, M, B, ne[3] };
-                        int64_t nev[4] = { M, D, B, ne[3] };
-                        if (ndims == 2) {
-                            neq[2] = 1; neq[3] = 1;
-                            nek[2] = 1; nek[3] = 1;
-                            nev[2] = 1; nev[3] = 1;
-                        } else if (ndims == 3) {
-                            neq[3] = 1;
-                            nek[3] = 1;
-                            nev[3] = 1;
-                        }
-                        x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                        x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                        x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                        ggml_set_param(ctx0, x[0]);
-                        ggml_set_param(ctx0, x[1]);
-                        ggml_set_param(ctx0, x[2]);
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-                        check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
-                    }
-                }
-            }
-        }
-
-        // flash_attn f16, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
-
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+        // TODO: adapt to ggml_flash_attn_ext() changes
+        //{
+        //    srand(seed);
+        //    const int nargs = 3;
+
+        //    int64_t ne2[4];
+
+        //    get_random_dims(ne2, 4);
+        //    int64_t D = ne2[0];
+        //    int64_t N = ne2[1];
+        //    int64_t M = ne2[2] + N;
+        //    int64_t B = ne2[3];
+
+        //    for (int masked = 0; masked <= 1; ++masked) {
+        //        for (int ndims = 2; ndims <= 4; ++ndims) {
+        //            int max_nrep = (ndims >= 3) ? 2 : 1;
+        //            for (int nrep = 1; nrep < max_nrep; ++nrep) {
+        //                int64_t neq[4] = { D, N, B*nrep, ne[3] };
+        //                int64_t nek[4] = { D, M, B, ne[3] };
+        //                int64_t nev[4] = { M, D, B, ne[3] };
+        //                if (ndims == 2) {
+        //                    neq[2] = 1; neq[3] = 1;
+        //                    nek[2] = 1; nek[3] = 1;
+        //                    nev[2] = 1; nev[3] = 1;
+        //                } else if (ndims == 3) {
+        //                    neq[3] = 1;
+        //                    nek[3] = 1;
+        //                    nev[3] = 1;
+        //                }
+        //                x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+        //                x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+        //                x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+        //                ggml_set_param(ctx0, x[0]);
+        //                ggml_set_param(ctx0, x[1]);
+        //                ggml_set_param(ctx0, x[2]);
+
+        //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+        //            }
+        //        }
+        //    }
+        //}
 
-                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
-                }
-            }
-        }
         ggml_free(ctx0);
     }