]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix FA when KV cache is not used (i.e. embeddings) (#12825)
authorGeorgi Gerganov <redacted>
Tue, 8 Apr 2025 16:54:51 +0000 (19:54 +0300)
committerGitHub <redacted>
Tue, 8 Apr 2025 16:54:51 +0000 (19:54 +0300)
* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci

examples/server/tests/unit/test_embedding.py
examples/server/tests/utils.py
examples/server_embd.py
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-metal/ggml-metal.m
src/llama-graph.cpp

index 8b0eb42b0926ff596327ecdbdf921b914a0e3026..0feb452ccfcd448af6695e2945840f5590da1615 100644 (file)
@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
index 30aa8660950a14c70ef9bfa7e698abb7d0588d36..4dc2062a8e5b9c18d3a8234163ea9e666b094ec0 100644 (file)
@@ -323,6 +323,21 @@ class ServerPreset:
         server.server_embeddings = True
         return server
 
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
index 0e34c6ceab9cabdcb04d4b562ca1d4589fd6a8d7..f8b0ffecd8f4718ac53894fbaf86789c8617bda4 100644 (file)
@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:
index 7a8d5ac6fd9d0d36879f7e936b0b9ee806f69ba4..f63656be54f5c3c69b604a019c4d2dd7f5c09f62 100644 (file)
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
index 456e1fd994c4041eaec76fa68e62661023860432..f226826020a5ad2bbe0093310eccaef59f7981fe 100644 (file)
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
index c3469177e091cdc10fa0abd861ce9de226258707..cd955d63bc390729f7a9f4863264a804cd2c53e9 100644 (file)
@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);