]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
llama : fix FA when KV cache is not used (i.e. embeddings) (llama/12825)
authorGeorgi Gerganov <redacted>
Tue, 8 Apr 2025 16:54:51 +0000 (19:54 +0300)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci

ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-metal/ggml-metal.m

index 9e86de5957341bec69d5f7b53e7b453fc4f075a5..6050147be70accd7fd25df343012ca4e54b48702 100644 (file)
@@ -6769,8 +6769,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6866,10 +6866,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
index 0c272e002538a5c3635e9e95c78caddab082016e..9f1c6c6ccc09ff3f3721c90b02ec37e389ca6a0e 100644 (file)
@@ -1346,6 +1346,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }