]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix non-causal mask for gemma 3 (#12615)
authorXuan-Son Nguyen <redacted>
Sat, 29 Mar 2025 23:07:37 +0000 (00:07 +0100)
committerGitHub <redacted>
Sat, 29 Mar 2025 23:07:37 +0000 (00:07 +0100)
src/llama-context.cpp
src/llama-graph.cpp

index 9467c3a010db480edc7f68da509f70ad4001f62a..3479a8cca3d6408e9a354b726806dd4854c5fb0e 100644 (file)
@@ -1317,8 +1317,8 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
+        // find KV slot
+        {
             kv_self_update();
 
             // if we have enough unused cells before the current head ->
index 0bd40174438cce54a077beafc8a4d4e2ce46a870..cec203df49268e215fc9a488e27964a179f2e8b8 100644 (file)
@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = kv_self->n;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv         = kv_self->n;
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
 
-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data     = nullptr;
+        float * data_swa = nullptr;
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }
 
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }
 
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                                f = 0.0f;
                             }
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
-                    }
-                }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }
 
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride     = n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }