llama : add support for EmbeddingGemma 300m (#15798)
author     Daniel Bevenius <redacted>
           Thu, 4 Sep 2025 16:10:29 +0000 (18:10 +0200)
committer  GitHub <redacted>
           Thu, 4 Sep 2025 16:10:29 +0000 (18:10 +0200)
This commit adds support for the EmbeddingGemma 300m model. This model
supports sliding window attention (SWA), and a new swa_type,
LLAMA_SWA_TYPE_SYMMETRIC, is introduced to support symmetric SWA masking.
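
Unlike the standard (look-back only) window, the symmetric variant masks
tokens on both sides of the query position. A minimal sketch of the
predicate, mirroring the LLAMA_SWA_TYPE_SYMMETRIC case added to
llama_hparams::is_masked_swa further down in this diff:
```cpp
#include <cstdint>

// p0 = key/value position, p1 = query position, n_swa = window size.
// A key is masked when it falls outside a window of roughly n_swa
// positions centered on the query, i.e. |p1 - p0| > n_swa / 2.
static bool is_masked_swa_symmetric(uint32_t n_swa, int32_t p0, int32_t p1) {
    const int32_t half_n_swa = (int32_t) n_swa / 2;
    const int32_t pos_diff   = p1 - p0;

    return pos_diff < -half_n_swa || pos_diff > half_n_swa;
}
```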

This commit also extracts the SWA masking logic out of
llama_kv_cache::is_masked_swa into the new llama_hparams::is_masked_swa,
so that the logic can be shared by both
llm_graph_input_attn_no_cache::set_input and
llama_kv_cache::set_input_kq_mask.
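
With the predicate on llama_hparams, both mask builders reduce to the same
pattern; a simplified sketch (not the exact implementation) of how a single
KQ-mask entry is decided:
```cpp
#include <cmath>
#include <cstdint>

using llama_pos = int32_t;

// stand-in for the shared llama_hparams::is_masked_swa(p0, p1) predicate
bool is_masked_swa(llama_pos p0, llama_pos p1);

// additive mask value for a (key, query) position pair
float kq_mask_value(llama_pos p_key, llama_pos p_query, bool causal_attn) {
    if (causal_attn && p_key > p_query) {
        return -INFINITY;  // future token under causal attention
    }
    if (is_masked_swa(p_key, p_query)) {
        return -INFINITY;  // outside the sliding window
    }
    return 0.0f;           // token can be attended to
}
```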

With this commit the EmbeddingGemma 300m model can be converted to
GGUF and used with llama.cpp.
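
For example (paths and quantization type below are illustrative, assuming a
local checkout of the Hugging Face model):
```console
python convert_hf_to_gguf.py /path/to/embeddinggemma-300m \
    --outfile embeddinggemma-300m-Q8_0.gguf --outtype q8_0
```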

Once the model has been uploaded to HuggingFace, it can be used like
this:
```console
./build/bin/llama-cli -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0
```
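
Since this is an embedding model, the llama-embedding example is the more
natural entry point. An untested sketch of such an invocation (the pooling
type is picked up from the GGUF metadata written by the conversion):
```console
./build/bin/llama-embedding -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0 \
    -p "Which planet is known as the Red Planet?"
```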

15 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-iswa.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-model.cpp

convert_hf_to_gguf.py
index 908435d7e5f00ce1d370e7c368e7dd47a0172b7a..717b8a6595010fb9270da734df372a2f7f41c1bc 100755 (executable)
@@ -5122,6 +5122,15 @@ class Gemma3Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Gemma3TextModel")
+class EmbeddingGemma(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+
 @ModelBase.register("Gemma3ForConditionalGeneration")
 class Gemma3VisionModel(MmprojModel):
     def set_gguf_parameters(self):
gguf-py/gguf/constants.py
index 6156d35c2ad67d40fdd32510d6e1aacff3c7bb5d..c922bacf6eb973488b26a55b590c6e8ce0fa190d 100644 (file)
@@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA2           = auto()
     GEMMA3           = auto()
     GEMMA3N          = auto()
+    GEMMA_EMBEDDING  = auto()
     STARCODER2       = auto()
     RWKV6            = auto()
     RWKV6QWEN2       = auto()
@@ -674,6 +675,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA2:           "gemma2",
     MODEL_ARCH.GEMMA3:           "gemma3",
     MODEL_ARCH.GEMMA3N:          "gemma3n",
+    MODEL_ARCH.GEMMA_EMBEDDING:  "gemma-embedding",
     MODEL_ARCH.STARCODER2:       "starcoder2",
     MODEL_ARCH.RWKV6:            "rwkv6",
     MODEL_ARCH.RWKV6QWEN2:       "rwkv6qwen2",
@@ -1719,6 +1721,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.LAUREL_R,
         MODEL_TENSOR.LAUREL_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA_EMBEDDING: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
gguf-py/gguf/tensor_mapping.py
index 497f48809f8a799fa3f5fc3db802aeb9db8ecb5a..b0c3d65e95847d9aa87250d82c2e19e1dd9d3bd6 100644 (file)
@@ -14,6 +14,7 @@ class TensorNameMap:
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
             "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
+            "embed_tokens",                              # embeddinggemma
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -141,6 +142,7 @@ class TensorNameMap:
             "rwkv.blocks.{bid}.ln1",                                # rwkv6
             "model.layers.{bid}.ln1",                               # rwkv7
             "model.layers.{bid}.input_layernorm",                   # llama4
+            "layers.{bid}.input_layernorm",                         # embeddinggemma
             "transformer_encoder.{bid}.attention_norm",             # neobert
             "model.layers.{bid}.operator_norm",                     # lfm2
             "model.transformer.blocks.{bid}.attn_norm",             # llada
@@ -179,6 +181,7 @@ class TensorNameMap:
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj",                       # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.q_proj",                             # embeddinggemma
             "model.layers.{bid}.self_attn.q_proj_no_perm",               # llama-custom
             "layers.{bid}.attention.wq",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",                  # bert
@@ -197,6 +200,7 @@ class TensorNameMap:
         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj",                     # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.k_proj",                           # embeddinggemma
             "model.layers.{bid}.self_attn.k_proj_no_perm",             # llama-custom
             "layers.{bid}.attention.wk",                               # llama-pth
             "encoder.layer.{bid}.attention.self.key",                  # bert
@@ -216,6 +220,7 @@ class TensorNameMap:
         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj",                       # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.v_proj",                             # embeddinggemma
             "layers.{bid}.attention.wv",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",                  # bert
             "transformer.layer.{bid}.attention.v_lin",                   # distillbert
@@ -239,6 +244,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.dense",                     # falcon
             "h.{bid}.self_attention.dense",                                 # bloom
             "model.layers.{bid}.self_attn.o_proj",                          # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.o_proj",                                # embeddinggemma
             "model.layers.{bid}.self_attn.out_proj",                        # lfm2
             "model.layers.{bid}.self_attn.linear_attn",                     # deci
             "layers.{bid}.attention.wo",                                    # llama-pth
@@ -277,6 +283,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_POST_NORM: (
             "model.layers.{bid}.post_attention_layernorm",       # gemma2 olmo2    # ge
+            "layers.{bid}.post_attention_layernorm",             # embeddinggemma
             "model.layers.{bid}.post_self_attn_layernorm",       # glm-4-0414
             "model.layers.layers.{bid}.post_mixer_norm.weight",  # plamo2
         ),
@@ -320,12 +327,14 @@ class TensorNameMap:
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "layers.{bid}.pre_feedforward_layernorm",       # embeddinggemma
             "model.layers.{bid}.pre_ff_layernorm.weight",
         ),
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
+            "layers.{bid}.post_feedforward_layernorm",       # embeddinggemma
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
@@ -362,6 +371,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
             "h.{bid}.mlp.dense_h_to_4h",                              # bloom
             "model.layers.{bid}.mlp.up_proj",                         # llama-hf refact nemotron olmo2
+            "layers.{bid}.mlp.up_proj",                               # embeddinggemma
             "layers.{bid}.feed_forward.w3",                           # llama-pth
             "encoder.layer.{bid}.intermediate.dense",                 # bert
             "transformer.layer.{bid}.ffn.lin1",                       # distillbert
@@ -421,6 +431,7 @@ class TensorNameMap:
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact olmo2
+            "layers.{bid}.mlp.gate_proj",                 # embeddinggemma
             "layers.{bid}.feed_forward.w1",               # llama-pth
             "transformer.h.{bid}.mlp.w2",                 # qwen
             "transformer.h.{bid}.mlp.c_fc2",              # jais
@@ -461,6 +472,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.dense_4h_to_h",                  # falcon
             "h.{bid}.mlp.dense_4h_to_h",                              # bloom
             "model.layers.{bid}.mlp.down_proj",                       # llama-hf nemotron olmo2
+            "layers.{bid}.mlp.down_proj",                             # embeddinggemma
             "layers.{bid}.feed_forward.w2",                           # llama-pth
             "encoder.layer.{bid}.output.dense",                       # bert
             "transformer.layer.{bid}.ffn.lin2",                       # distillbert
@@ -513,6 +525,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.query_layernorm",                   # hunyuan
             "model.layers.{bid}.self_attn.q_norm",                            # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.q_norm",                                  # embeddinggemma
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",                # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",                           # openelm
@@ -525,6 +538,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.key_layernorm",                     # hunyuan
             "model.layers.{bid}.self_attn.k_norm",                            # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.k_norm",                                  # embeddinggemma
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",                # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",                           # openelm
src/llama-arch.cpp
index d5c8477f4aa39e97f128f6e96fc9ffff3ce28686..d92bf2dd8515fcb0a616b0210cbc53ddf007bcbd 100644 (file)
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2,           "gemma2"           },
     { LLM_ARCH_GEMMA3,           "gemma3"           },
     { LLM_ARCH_GEMMA3N,          "gemma3n"          },
+    { LLM_ARCH_GEMMA_EMBEDDING,  "gemma-embedding"  },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_MAMBA2,           "mamba2"           },
@@ -1038,6 +1039,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA_EMBEDDING,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
src/llama-arch.h
index 86c119692d8cc0e5bfdb23c2e0c0cc21836435f6..41f377923f50f7dc564c6b80a9ab3bf560c82eab 100644 (file)
@@ -49,6 +49,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
src/llama-graph.cpp
index 49ea5da7cb42219f92d764fddfe36e6d293b7dfb..7ce2960eb311b22c9b06f0d7678b507868aef1ef 100644 (file)
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
+                          (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
+                          (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
+                          (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+    LLAMA_LOG_DEBUG("    ");
+    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+        LLAMA_LOG_DEBUG("%2d", j);
+    }
+    LLAMA_LOG_DEBUG("\n");
+
+    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+        LLAMA_LOG_DEBUG(" %2d ", i);
+        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+            float val = data[i * n_kv + j];
+            if (val == -INFINITY) {
+                LLAMA_LOG_DEBUG(" âˆž");
+            } else {
+                LLAMA_LOG_DEBUG(" 0");
+            }
+        }
+        LLAMA_LOG_DEBUG("\n");
+    }
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -277,21 +307,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache
-                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                        } else {
-                            f = 0.0f;
-                        }
-                        break;
+                    if (s0 != s1) {
+                        continue; // skip different sequences
                     }
-                }
 
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
+                    }
+
+                    if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                        continue; // skip masked tokens for SWA
+                    }
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
+                    }
+                }
                 data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+    }
 }
 
 void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
src/llama-graph.h
index 3c85333fde5145a424949dda079d3d8e15d33af9..ca90fdf613f6de0074e465ca414e7b80c070e85a 100644 (file)
@@ -78,6 +78,11 @@ struct llm_graph_params;
 
 class llm_graph_input_i {
 public:
+    llm_graph_input_i() {
+        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+    }
+
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
         GGML_UNUSED(params);
         return false;
     }
+protected:
+    // env: LLAMA_GRAPH_INPUT_DEBUG
+    int debug = 0;
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
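
The debug member added above is read from the LLAMA_GRAPH_INPUT_DEBUG
environment variable and gates the print_mask() output in llama-graph.cpp.
Presumably it is enabled along the lines of the following (debug-level log
output must also be visible, e.g. via a verbose log setting):
```console
LLAMA_GRAPH_INPUT_DEBUG=1 ./build/bin/llama-embedding \
    -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0 -p "test" --verbose
```
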
src/llama-hparams.cpp
index 91636572da8b29d61722c59c301806cd280397d0..4b7a65362b9f04d17d0300789e69902a8c302a1c 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+#include <cassert>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const {
 
     return res;
 }
+
+bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    assert(p0 >= 0 && p1 >= 0);
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff = p1 - p0;
+
+                // Mask if outside the symmetric window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
src/llama-hparams.h
index 60415f0c202a4a039456fceaea898979dd2d396f..06d1e51db055234e7a5ada4bfa2a2f8fce4a084f 100644 (file)
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
-    LLAMA_SWA_TYPE_NONE     = 0,
-    LLAMA_SWA_TYPE_STANDARD = 1,
-    LLAMA_SWA_TYPE_CHUNKED  = 2,
+    LLAMA_SWA_TYPE_NONE      = 0,
+    LLAMA_SWA_TYPE_STANDARD  = 1,
+    LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -227,6 +228,8 @@ struct llama_hparams {
 
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
src/llama-kv-cache-iswa.cpp
index d7342914c6b7cbf17eeb8cae258b12413e16d6e7..7c51a1981e05e66688ef0a6edd3afcc237802fca 100644 (file)
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
+            0, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
+            hparams.n_swa, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
src/llama-kv-cache.cpp
index f1c6918738684c1ace92fa7c010a437423791abc..1564faedf4eb48658ca6c9ee5c979a01513d9efd 100644 (file)
@@ -27,11 +27,10 @@ llama_kv_cache::llama_kv_cache(
                  uint32_t   n_seq_max,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
-           llama_swa_type   swa_type,
     const layer_filter_cb & filter,
     const  layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -1393,29 +1392,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
+    return hparams.is_masked_swa(p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
src/llama-kv-cache.h
index 07d29bb8187e27d1e20a3c36c1bcb2a142591398..55ee355b22928059696b59925fd43848d5adfb3a 100644 (file)
@@ -89,7 +89,6 @@ public:
                      uint32_t   n_seq_max,
                      uint32_t   n_pad,
                      uint32_t   n_swa,
-               llama_swa_type   swa_type,
         const layer_filter_cb & filter,
         const  layer_reuse_cb & reuse);
 
@@ -212,8 +211,6 @@ private:
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
-    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
src/llama-memory-hybrid.cpp
index ba61ebaa885feffc62ac012563736f59fa4325eb..38f92da11820d5ce1275968adb7f90433169f4f6 100644 (file)
@@ -17,7 +17,6 @@ llama_memory_hybrid::llama_memory_hybrid(
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
-           llama_swa_type   swa_type,
                             /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
@@ -41,7 +40,6 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
-        swa_type,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
src/llama-memory-hybrid.h
index 11a35651782974023d48be9ddba531c0c79e0a26..0eb63f5ef86d7ee2bccb536b497e16672e7cbb7c 100644 (file)
@@ -27,7 +27,6 @@ public:
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
-           llama_swa_type   swa_type,
                             /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
src/llama-model.cpp
index 5e54ce25f12988e2d6c684dc400c891e73008573..afc4cb48098d21e0933813f9342de8194ac3fbf8 100644 (file)
@@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+                hparams.set_swa_pattern(6);
+
+                hparams.causal_attn = false; // embeddings do not use causal attention
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_GEMMA3:
+            case LLM_ARCH_GEMMA_EMBEDDING:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -11045,6 +11066,136 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
     llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -18481,6 +18632,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18529,7 +18681,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* attn_kv_size      */ cparams.n_ctx,
                         /* attn_n_pad        */ padding,
                         /* attn_n_swa        */ hparams.n_swa,
-                        /* attn_swa_type     */ hparams.swa_type,
                         /* recurrent_type_k  */ GGML_TYPE_F32,
                         /* recurrent_type_v  */ GGML_TYPE_F32,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18599,7 +18750,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
-                                hparams.swa_type,
                                 nullptr,
                                 nullptr);
                     }
@@ -18761,6 +18911,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -19161,6 +19315,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
         case LLM_ARCH_GEMMA3N:
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX: