kv-cache : opt mask set input (#14600)
author     Georgi Gerganov <redacted>
           Thu, 17 Jul 2025 06:49:15 +0000 (09:49 +0300)
committer  GitHub <redacted>
           Thu, 17 Jul 2025 06:49:15 +0000 (09:49 +0300)
ggml-ci

src/llama-kv-cache-unified.cpp

index 7e92e6b4df9d4419578f940bafd4c4290f21389c..baaa1d32dffb55a2bd83c7f5ffce872e39590a3b 100644 (file)
@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
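
What the change does, in short: instead of computing a per-element `masked` flag and writing either 0.0f or -INFINITY into every slot of the KQ mask (including the padded token rows, which needed their own loop), the destination buffer is first filled wholesale with -INFINITY via std::fill, and the inner loop then only writes the entries that are actually visible, skipping masked cells with an early continue. The sketch below is a minimal standalone illustration of that pattern under simplified assumptions, not the real llama.cpp code: it uses a made-up Cell struct with a single sequence id per cell, drops the SWA window check, and the helper name fill_mask_row is invented for this example.

    // Minimal sketch of the "pre-fill with -INFINITY, write only visible cells" pattern.
    // Simplified stand-ins for the real llama.cpp structures; fill_mask_row is hypothetical.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Cell {
        bool    empty;   // true if the KV cell is unused
        int32_t pos;     // position of the cached token
        int32_t seq_id;  // single sequence id, for simplicity
    };

    // Fill one mask row for a batch token at position p1 belonging to sequence seq_id.
    // Entries that stay masked are never touched: they keep the -INFINITY default.
    static void fill_mask_row(float * row, const std::vector<Cell> & cells,
                              int32_t p1, int32_t seq_id,
                              bool causal_attn, bool use_alibi) {
        for (size_t j = 0; j < cells.size(); ++j) {
            const Cell & c = cells[j];

            if (c.empty)                   continue; // unused cell
            if (c.seq_id != seq_id)        continue; // different sequence
            if (causal_attn && c.pos > p1) continue; // future token

            // visible: 0.0f, or an ALiBi-style distance bias
            row[j] = use_alibi ? -std::abs(float(c.pos - p1)) : 0.0f;
        }
    }

    int main() {
        const size_t n_kv = 8;
        std::vector<Cell> cells(n_kv, Cell{true, 0, 0});
        cells[0] = {false, 0, 0};
        cells[1] = {false, 1, 0};
        cells[2] = {false, 2, 1}; // other sequence -> stays masked

        // the key step: everything starts masked, padded rows included
        std::vector<float> mask(n_kv, -INFINITY);

        fill_mask_row(mask.data(), cells, /*p1=*/3, /*seq_id=*/0,
                      /*causal_attn=*/true, /*use_alibi=*/false);

        for (float f : mask) {
            printf("%s ", std::isinf(f) ? "-inf" : "0");
        }
        printf("\n"); // expected: 0 0 -inf -inf -inf -inf -inf -inf
        return 0;
    }

Because the buffer is pre-filled, the separate "mask padded tokens" loop in the old code becomes unnecessary: rows for ii in [n_tps, n_tps_pad) simply keep their -INFINITY default, and the hot loop only does work for the cells that end up visible.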