]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 14 Oct 2025 19:09:02 +0000 (22:09 +0300)
committerGeorgi Gerganov <redacted>
Wed, 15 Oct 2025 06:29:17 +0000 (09:29 +0300)
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama.cpp

index a24853c63ada4006f10b059a71e5bc22750ed2c0..f29a1e98c9103e1135962d60aa5ff0a2e054e372 100644 (file)
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                          (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                          (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                          (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos p0    = ubatch->pos[i0];
 
+                    // mask different sequences
                     if (s0 != s1) {
-                        continue; // skip different sequences
+                        continue;
                     }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
 
@@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
-    // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
index dc84b7942893a10d4084950c70d529ebcd051c67..d0c3934f67927e1f8218790d110fc60fe0619d93 100644 (file)
@@ -257,10 +257,14 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
+    // n_tokens == n_batch
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams hparams;
     const llama_cparams cparams;
index 36d495d6cfeab7ccb5bbb76922acd7f33e837486..0cdad9babd9b27c083488ecefd25f0ee02d1cbfa 100644 (file)
@@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma_embedding_iswa : public llm_graph_context {
-    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding : public llm_graph_context {
+    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
@@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
-                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+                llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
             } break;
         case LLM_ARCH_STARCODER2:
             {
index fe5a7a835488c5a16c4219521fd20a472e7b202d..38700f97a068818f186fffe9b0480866e2ae75a0 100644 (file)
@@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
+    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }