From: Georgi Gerganov
Date: Thu, 17 Jul 2025 06:49:15 +0000 (+0300)
Subject: kv-cache : opt mask set input (#14600)
X-Git-Tag: upstream/0.0.6073~153
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d9b691081c04ec5fb0daa9d2b979f915c142963d;p=pkg%2Fggml%2Fsources%2Fllama.cpp

kv-cache : opt mask set input (#14600)

ggml-ci
---

diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 7e92e6b4..baaa1d32 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
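
The idea behind the change: initializing the whole mask to -INFINITY with one std::fill turns every "masked" case (empty cell, wrong sequence, future token, SWA, and the padded rows that previously needed their own loop) into a plain continue, so the inner loop only writes the visible entries, with the row offset idst hoisted out of the j loop. Below is a minimal standalone sketch of this fill-then-overwrite pattern; the names and values (n_kv, n_tps, n_tps_pad, cell_pos, tok_pos) are illustrative toys rather than the llama.cpp API, and the sequence and SWA checks are omitted for brevity:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_kv      = 8; // KV cells visible to the batch (toy value)
        const int n_tps     = 3; // tokens per stream in the ubatch (toy value)
        const int n_tps_pad = 4; // token count after padding (toy value)

        // toy cache state: position of each occupied cell, -1 = empty cell
        const int cell_pos[n_kv] = { 0, 1, 2, 3, 4, -1, -1, -1 };

        // toy positions of the batch tokens
        const int tok_pos[n_tps] = { 5, 6, 7 };

        std::vector<float> mask(n_tps_pad*n_kv);

        // a single fill marks everything masked up front: empty cells, future
        // tokens and the padded rows [n_tps, n_tps_pad) never need a write
        std::fill(mask.begin(), mask.end(), -INFINITY);

        for (int ii = 0; ii < n_tps; ++ii) {
            const int p1   = tok_pos[ii];
            const int idst = ii*n_kv; // row base, hoisted like idst in the patch

            for (int j = 0; j < n_kv; ++j) {
                const int p0 = cell_pos[j];

                if (p0 < 0) {
                    continue; // empty cell -> stays -INF
                }

                if (p0 > p1) {
                    continue; // future token -> stays -INF (causal)
                }

                // visible entry; with ALiBi this would be -std::abs(p0 - p1)
                mask[idst + j] = 0.0f;
            }
        }

        // print the mask: 0 where attention is allowed, -inf elsewhere;
        // the last row is entirely -inf because it is padding
        for (int ii = 0; ii < n_tps_pad; ++ii) {
            for (int j = 0; j < n_kv; ++j) {
                printf("%5.0f", mask[ii*n_kv + j]);
            }
            printf("\n");
        }

        return 0;
    }

Besides saving the separate padded-token loop, this shape replaces the per-element masked flag and branchy float selection with early exits, which is where the optimization in the patch comes from.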