]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add LFM2-ColBert-350M (#18607)
authorTarek Dakhran <redacted>
Mon, 5 Jan 2026 18:52:56 +0000 (19:52 +0100)
committerGitHub <redacted>
Mon, 5 Jan 2026 18:52:56 +0000 (19:52 +0100)
* model : add LFM2-ColBert-350M

* llama_model_n_embd_out() - returns `hparams.n_embd_out` if set and fallbacks to `hparams.n_embd`

16 files changed:
convert_hf_to_gguf.py
examples/embedding/embedding.cpp
examples/model-conversion/logits.cpp
examples/retrieval/retrieval.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-context.cpp
src/llama-graph.cpp
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-model-saver.cpp
src/llama-model.cpp
tools/server/server-context.cpp

index d944032c6985119ee24b07b570db41a404b985b3..d9ee390b3835c110946249fce0ed42eef1ea7ee0 100755 (executable)
@@ -9956,6 +9956,27 @@ class LFM2Model(TextModel):
         return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])
 
 
+@ModelBase.register("Lfm2Model")
+class LFM2ColBertModel(LFM2Model):
+    model_arch = gguf.MODEL_ARCH.LFM2
+    dense_tensor_name = "dense_2"
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if not name.startswith(self.dense_tensor_name):
+            name = "model." + name
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # dense tensor is stored in a separate safetensors file
+        from safetensors.torch import load_file
+        tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
+        assert tensors_file.is_file()
+        tensor = load_file(tensors_file)["linear.weight"]
+        self.gguf_writer.add_embedding_length_out(tensor.shape[0])
+        yield f"{self.dense_tensor_name}.weight", tensor.clone()
+
+
 @ModelBase.register("Lfm2MoeForCausalLM")
 class LFM2MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LFM2MOE
index 81111e81b2ce1f085d60e543f3938ba45cb5180b..d8eaaa2691f94519c6e273ed9e36de9c0bb02f3f 100644 (file)
@@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
@@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
         }
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+        float * out = output + embd_pos * n_embd_out;
+        common_embd_normalize(embd, out, n_embd_out, embd_norm);
     }
 }
 
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
     }
 
     // allocate output
-    const int n_embd = llama_model_n_embd(model);
-    std::vector<float> embeddings(n_embd_count * n_embd, 0);
+    const int n_embd_out = llama_model_n_embd_out(model);
+    std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
     float * emb = embeddings.data();
 
     // break into batches
@@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
 
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
-            float * out = emb + e * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+            float * out = emb + e * n_embd_out;
+            batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
             common_batch_clear(batch);
@@ -280,8 +280,8 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + e * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+    float * out = emb + e * n_embd_out;
+    batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
 
     if (params.embd_out.empty()) {
         LOG("\n");
@@ -289,19 +289,19 @@ int main(int argc, char ** argv) {
         if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
             for (int j = 0; j < n_embd_count; j++) {
                 LOG("embedding %d: ", j);
-                for (int i = 0; i < std::min(3, n_embd); i++) {
+                for (int i = 0; i < std::min(3, n_embd_out); i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG(" ... ");
-                for (int i = n_embd - 3; i < n_embd; i++) {
+                for (int i = n_embd_out - 3; i < n_embd_out; i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG("\n");
@@ -320,9 +320,9 @@ int main(int argc, char ** argv) {
                 for (uint32_t i = 0; i < n_cls_out; i++) {
                     // NOTE: if you change this log - update the tests in ci/run.sh
                     if (n_cls_out == 1) {
-                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
                     } else {
-                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
                     }
                 }
             }
@@ -330,11 +330,11 @@ int main(int argc, char ** argv) {
             // print the first part of the embeddings or for a single prompt, the full embedding
             for (int j = 0; j < n_prompts; j++) {
                 LOG("embedding %d: ", j);
-                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG("\n");
@@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
                 LOG("\n");
                 for (int i = 0; i < n_prompts; i++) {
                     for (int j = 0; j < n_prompts; j++) {
-                        float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                        float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
                         LOG("%6.2f ", sim);
                     }
                     LOG("%1.10s", prompts[i].c_str());
@@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
             if (notArray) LOG("    {\n      \"object\": \"embedding\",\n      \"index\": %d,\n      \"embedding\": ",j);
             LOG("[");
             for (int i = 0;;) { // at least one iteration (n_embd > 0)
-                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
                 i++;
-                if (i < n_embd) LOG(","); else break;
+                if (i < n_embd_out) LOG(","); else break;
             }
             LOG(notArray ? "]\n    }" : "]");
             j++;
@@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
             for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
                 LOG("    [");
                 for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
-                    float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
                     LOG("%6.2f", sim);
                     j++;
                     if (j < n_embd_count) LOG(", "); else break;
@@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
 
         if (notArray) LOG("\n}\n");
     } else if (params.embd_out == "raw") {
-        print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
+        print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
     }
 
     LOG("\n");
index 5bcf0632677edbe1ac83d6bfc7a0feac3477a054..f71f772ab114489b362dcadc90b3c485802b296b 100644 (file)
@@ -161,9 +161,9 @@ int main(int argc, char ** argv) {
     std::vector<float> embd_out;
 
     if (embedding_mode) {
-        const int n_embd = llama_model_n_embd(model);
+        const int n_embd_out = llama_model_n_embd_out(model);
         const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
-        const int n_embeddings = n_embd * n_embd_count;
+        const int n_embeddings = n_embd_out * n_embd_count;
         float * embeddings;
         type = "-embeddings";
 
@@ -177,7 +177,7 @@ int main(int argc, char ** argv) {
             embeddings = llama_get_embeddings(ctx);
         }
 
-        printf("Embedding dimension: %d\n", n_embd);
+        printf("Embedding dimension: %d\n", n_embd_out);
         printf("\n");
 
         // Print embeddings in the specified format
@@ -185,16 +185,16 @@ int main(int argc, char ** argv) {
             printf("embedding %d: ", j);
 
             // Print first 3 values
-            for (int i = 0; i < 3 && i < n_embd; i++) {
-                printf("%9.6f ", embeddings[j * n_embd + i]);
+            for (int i = 0; i < 3 && i < n_embd_out; i++) {
+                printf("%9.6f ", embeddings[j * n_embd_out + i]);
             }
 
             printf(" ... ");
 
             // Print last 3 values
-            for (int i = n_embd - 3; i < n_embd; i++) {
+            for (int i = n_embd_out - 3; i < n_embd_out; i++) {
                 if (i >= 0) {
-                    printf("%9.6f ", embeddings[j * n_embd + i]);
+                    printf("%9.6f ", embeddings[j * n_embd_out + i]);
                 }
             }
 
index 8f92ff905786c9f17aa24bdcc1d85b0f7be5bc29..3f2afd4346e29c185ddebc90e6d918215861e781 100644 (file)
@@ -217,8 +217,8 @@ int main(int argc, char ** argv) {
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // allocate output
-    const int n_embd = llama_model_n_embd(model);
-    std::vector<float> embeddings(n_chunks * n_embd, 0);
+    const int n_embd_out = llama_model_n_embd_out(model);
+    std::vector<float> embeddings(n_chunks * n_embd_out, 0);
     float * emb = embeddings.data();
 
     // break into batches
@@ -232,8 +232,8 @@ int main(int argc, char ** argv) {
 
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
-            float * out = emb + p * n_embd;
-            batch_process(ctx, batch, out, s, n_embd);
+            float * out = emb + p * n_embd_out;
+            batch_process(ctx, batch, out, s, n_embd_out);
             common_batch_clear(batch);
             p += s;
             s = 0;
@@ -245,12 +245,12 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
-    batch_process(ctx, batch, out, s, n_embd);
+    float * out = emb + p * n_embd_out;
+    batch_process(ctx, batch, out, s, n_embd_out);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
-        chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
+        chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out);
         // clear tokens as they are no longer needed
         chunks[i].tokens.clear();
     }
@@ -266,8 +266,8 @@ int main(int argc, char ** argv) {
 
         batch_add_seq(query_batch, query_tokens, 0);
 
-        std::vector<float> query_emb(n_embd, 0);
-        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
+        std::vector<float> query_emb(n_embd_out, 0);
+        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out);
 
         common_batch_clear(query_batch);
 
@@ -275,7 +275,7 @@ int main(int argc, char ** argv) {
         {
             std::vector<std::pair<int, float>> similarities;
             for (int i = 0; i < n_chunks; i++) {
-                float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
+                float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out);
                 similarities.push_back(std::make_pair(i, sim));
             }
 
index c8feca5679b9a2d26ef0ca7e3fc0b52da47e1ed5..64c227799f43e35bb31a0ceb4004c8d70b9d9422 100644 (file)
@@ -104,6 +104,7 @@ class Keys:
         VOCAB_SIZE                        = "{arch}.vocab_size"
         CONTEXT_LENGTH                    = "{arch}.context_length"
         EMBEDDING_LENGTH                  = "{arch}.embedding_length"
+        EMBEDDING_LENGTH_OUT              = "{arch}.embedding_length_out"
         FEATURES_LENGTH                   = "{arch}.features_length"
         BLOCK_COUNT                       = "{arch}.block_count"
         LEADING_DENSE_BLOCK_COUNT         = "{arch}.leading_dense_block_count"
@@ -3038,6 +3039,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.DENSE_2_OUT, # LFM2-ColBert-350M
     ],
     MODEL_ARCH.LFM2MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
index 612a978e4c30c82dc667f75191d435299b51885d..a7506aa793450f2a8c189cdd2f5c3513773d2d59 100644 (file)
@@ -681,6 +681,9 @@ class GGUFWriter:
     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
 
+    def add_embedding_length_out(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)
+
     def add_features_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
 
index bf4ce5f927fc93866b94cdee6337e9a8e5f6c242..05cb6532542997be33d1359b20d6084075d89a5f 100644 (file)
@@ -535,6 +535,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
index 93fed1a9a3cbea994fc764f9694e3fc46ebba96f..2ead965469a038e9158b172c8d770dfee50bb304 100644 (file)
@@ -152,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
     { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },
     { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"                  },
+    { LLM_KV_EMBEDDING_LENGTH_OUT,              "%s.embedding_length_out"              },
     { LLM_KV_FEATURES_LENGTH,                   "%s.features_length"                   },
     { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                       },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
@@ -2075,6 +2076,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_TOKEN_EMBD,
                 LLM_TENSOR_OUTPUT_NORM_LFM2,
                 LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_DENSE_2_OUT,
             };
         case LLM_ARCH_LFM2MOE:
             return {
index 57e470a9f3885430c6d305139c2abccba759cab6..68ec6a18b185f0a4c944aa0ba32175bf31f6a1f8 100644 (file)
@@ -156,6 +156,7 @@ enum llm_kv {
     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
+    LLM_KV_EMBEDDING_LENGTH_OUT,
     LLM_KV_FEATURES_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
index 9c2e1c17a3ad39ef2f87711a06d4152a98c4d75f..f220010a1b4cce8b422942d3492c8aa20bb451ad 100644 (file)
@@ -758,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        return embd + j*model.hparams.n_embd;
+        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        return embd + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -1194,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();
 
-                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
@@ -1600,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     {
                         // extract token embeddings
                         GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd + n_outputs_prev*n_embd;
+                        const uint32_t n_embd_out = hparams.get_n_embd_out();
+                        float * embd_out = embd + n_outputs_prev*n_embd_out;
 
                         if (n_outputs) {
                             GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_MEAN:
@@ -1730,9 +1733,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
-    const auto n_embd  = hparams.n_embd;
+    const auto n_batch    = cparams.n_batch;
+    const auto n_vocab    = vocab.n_tokens();
+    const auto n_embd_out = hparams.get_n_embd_out();
 
     bool has_logits = true;
     bool has_embd   = cparams.embeddings;
@@ -1773,7 +1776,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
 
     // Allocate CPU logits buffer only if needed by sequences in this batch
     logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd ? n_embd*n_outputs_max : 0;
+    embd_size   = has_embd ? n_embd_out*n_outputs_max : 0;
 
     // TODO: avoid this branching by working with the worst-case
     if (!has_sampling) {
index 86c547263835eb5d508de7fcd2b466ad65248214..374ff1ebf3a2a41d498b6a0476593272b0707518 100644 (file)
@@ -2071,14 +2071,18 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 void llm_graph_context::build_dense_out(
     ggml_tensor * dense_2,
     ggml_tensor * dense_3) const {
-    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+    if (!cparams.embeddings || !(dense_2 || dense_3)) {
         return;
     }
     ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
     GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
 
-    cur = ggml_mul_mat(ctx0, dense_2, cur);
-    cur = ggml_mul_mat(ctx0, dense_3, cur);
+    if (dense_2) {
+        cur = ggml_mul_mat(ctx0, dense_2, cur);
+    }
+    if (dense_3) {
+        cur = ggml_mul_mat(ctx0, dense_3, cur);
+    }
     cb(cur, "result_embd_pooled", -1);
     res->t_embd_pooled = cur;
     ggml_build_forward_expand(gf, cur);
index fe1fa4341d4288c209bc2ccebb6141fd8c9a1e7c..c847ef91b7aa26b8a4e76b94675550c6293ece12 100644 (file)
@@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const {
     return n_embd_inp;
 }
 
+uint32_t llama_hparams::get_n_embd_out() const {
+    return n_embd_out > 0 ? n_embd_out : n_embd;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
index fc5708fc4b01a64c7a497246ff8d55153bc2e54e..7ae3ec292efed1c8cd62f0dda95a868e44861c94 100644 (file)
@@ -162,6 +162,9 @@ struct llama_hparams {
     // for Classifiers
     uint32_t n_cls_out = 1;
 
+    // output embedding dimension (0 = use n_embd)
+    uint32_t n_embd_out = 0;
+
     // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
@@ -234,6 +237,9 @@ struct llama_hparams {
     // dimension of main + auxiliary input embeddings
     uint32_t n_embd_inp() const;
 
+    // dimension of output embeddings
+    uint32_t get_n_embd_out() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
index 563823dc35d8eef29b5e7b62589c6500face3a97..ae27c71ce2300ac5cc477aca2f47a9d86625ef90 100644 (file)
@@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() {
     add_kv(LLM_KV_VOCAB_SIZE,                        vocab.n_tokens());
     add_kv(LLM_KV_CONTEXT_LENGTH,                    hparams.n_ctx_train);
     add_kv(LLM_KV_EMBEDDING_LENGTH,                  hparams.n_embd);
+    if (hparams.n_embd_out > 0) {
+        add_kv(LLM_KV_EMBEDDING_LENGTH_OUT,          hparams.n_embd_out);
+    }
     add_kv(LLM_KV_BLOCK_COUNT,                       hparams.n_layer);
     add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
     add_kv(LLM_KV_FEED_FORWARD_LENGTH,               hparams.n_ff_arr, true);
index 28dcc2840f0a6239d0776f9439e4d24263c77f5b..04c48b5fd3f1e47e4f20b57a28d4129325a2f22c 100644 (file)
@@ -507,6 +507,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     ml.get_key(LLM_KV_CONTEXT_LENGTH,          hparams.n_ctx_train);
     ml.get_key(LLM_KV_EMBEDDING_LENGTH,        hparams.n_embd);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT,    hparams.n_embd_out, false);
     ml.get_key(LLM_KV_BLOCK_COUNT,             hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,            hparams.n_expert,        false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT,       hparams.n_expert_used,   false);
@@ -6469,6 +6470,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
                         }
                     }
+
+                    // for LFM2-ColBert-350M
+                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED);
                 } break;
             case LLM_ARCH_SMALLTHINKER:
                 {
@@ -8003,6 +8007,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) {
     return model->hparams.n_embd_inp();
 }
 
+int32_t llama_model_n_embd_out(const llama_model * model) {
+    return model->hparams.get_n_embd_out();
+}
+
 int32_t llama_model_n_layer(const llama_model * model) {
     return model->hparams.n_layer;
 }
index 5a6223b29cb00e398a9a880cdec9b8d4d7fa3193..33635a158664ce8d0237a3144a98a29f4fb895bd 100644 (file)
@@ -1505,9 +1505,9 @@ private:
         res->n_tokens  = slot.task->n_tokens();
         res->res_type  = slot.task->params.res_type;
 
-        const int n_embd = llama_model_n_embd(model);
+        const int n_embd_out = llama_model_n_embd_out(model);
 
-        std::vector<float> embd_res(n_embd, 0.0f);
+        std::vector<float> embd_res(n_embd_out, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1524,18 +1524,18 @@ private:
             if (embd == nullptr) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
+                res->embedding.push_back(std::vector<float>(n_embd_out, 0.0f));
                 continue;
             }
 
             // normalize only when there is pooling
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-                common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
+                common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
                 res->embedding.push_back(embd_res);
                 break;
             }
 
-            res->embedding.emplace_back(embd, embd + n_embd);
+            res->embedding.emplace_back(embd, embd + n_embd_out);
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");