]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : always use non-causal attention for encoder graphs (#12447)
authorGeorgi Gerganov <redacted>
Tue, 18 Mar 2025 11:05:49 +0000 (13:05 +0200)
committerGitHub <redacted>
Tue, 18 Mar 2025 11:05:49 +0000 (13:05 +0200)
* context : always use non-causal attention for encoder graphs

ggml-ci

* context : move the change to llama_context::encode()

ggml-ci

src/llama-context.cpp

index abb7e526f61711154bff3de32f1909a5735d3a88..42332acf1e39d7e64f88cdfd017c6664a38adbf2 100644 (file)
@@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
+    const auto causal_attn_org = cparams.causal_attn;
+
+    // always use non-causal attention for encoder graphs
+    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
+    cparams.causal_attn = false;
+
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
 
@@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    cparams.causal_attn = causal_attn_org;
+
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS: