model-conversion : add support for SentenceTransformers (#16387)
author Daniel Bevenius <redacted>
Thu, 9 Oct 2025 12:35:22 +0000 (14:35 +0200)
committer GitHub <redacted>
Thu, 9 Oct 2025 12:35:22 +0000 (14:35 +0200)
* model-conversion : add support for SentenceTransformers

This commit adds support for models that use SentenceTransformer layers.

The motivation for this is that if a converted model includes any of the
numbered layers specified in the original model's repository, these changes
enable such models to be used and verified. Currently, the model-conversion
example only supports the base model output, without any of the additional
transformation layers.

Usage:
Convert a model whose repository also includes the SentenceTransformer layers:
```console
(venv) $ export EMBEDDING_MODEL_PATH="~/google/embeddinggemma-300M"
(venv) $ make embedding-convert-model
```

Verify the produced embeddings from the converted model against the
original model embeddings:
```console
(venv) $ make embedding-verify-logits-st
```
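
Under the hood, the verification compares the two binary embedding files
written by the original-model and converted-model runs; in the pooled case
each file holds a single sentence embedding. A minimal NumPy sketch of that
comparison (the model name below is illustrative; the scripts derive the
actual file names from the model directory):
```python
import numpy as np

# Illustrative model name; the scripts derive the real file names from the model path.
model_name = "embeddinggemma-300M"

# Each file contains one pooled, normalized sentence embedding stored as float32.
py_emb  = np.fromfile(f"data/pytorch-{model_name}-embeddings.bin", dtype=np.float32)
cpp_emb = np.fromfile(f"data/llamacpp-{model_name}-embeddings.bin", dtype=np.float32)

# Cosine similarity between the two pooled embeddings; a value close to 1.0
# indicates the converted model matches the original.
cos = np.dot(py_emb, cpp_emb) / (np.linalg.norm(py_emb) * np.linalg.norm(cpp_emb))
print(f"cosine similarity: {cos:.6f}")
```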

The original model can be run using SentenceTransformer:
```console
(venv) $ make embedding-run-original-model-st
```

Run the converted model using the "SentenceTransformer" layers, which enables
pooling and normalization:
```console
(venv) $ make embedding-run-converted-model-st
```
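
For context, the `-pooling` path added to `llama-logits` uses mean pooling
(`LLAMA_POOLING_TYPE_MEAN`) followed by L2 normalization (`-embd-norm 2` by
default). A rough NumPy sketch of those two steps, assuming per-token
embeddings of shape [n_tokens, n_embd] (the dimensions below are made up, and
any Dense modules from the model repository are not part of this sketch):
```python
import numpy as np

def pool_and_normalize(token_embeddings: np.ndarray) -> np.ndarray:
    """Mean-pool per-token embeddings and L2-normalize the result."""
    pooled = token_embeddings.mean(axis=0)   # mean pooling over the token axis
    norm = np.linalg.norm(pooled)            # Euclidean (L2) norm
    return pooled / norm if norm > 0 else pooled

# Dummy data: 5 tokens with a 768-dimensional hidden size.
tokens = np.random.rand(5, 768).astype(np.float32)
sentence_embedding = pool_and_normalize(tokens)
print(sentence_embedding.shape)              # (768,)
```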

* add model-conversion example requirements

* add support for -st flag in embedding model conversion

This commit adds support for the -st flag in the embedding model conversion
script. This enables models to be converted using the SentenceTransformers
dense layers.
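
For orientation, these are the numbered modules stored next to the base model
in a SentenceTransformers repository (e.g. 01_Pooling, 02_Dense, 03_Dense,
04_Normalize). A small sketch for listing them with the sentence-transformers
library (the path is illustrative):
```python
from sentence_transformers import SentenceTransformer

# Illustrative local path; a Hugging Face model id works as well.
model = SentenceTransformer("/path/to/embeddinggemma-300M")

# A SentenceTransformer behaves like a torch Sequential: module 0 wraps the
# base transformer, the remaining entries are the numbered layers applied in order.
for idx, module in enumerate(model):
    print(idx, type(module).__name__)
# Typical entries: Transformer, Pooling, Dense, Dense, Normalize
```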

examples/model-conversion/Makefile
examples/model-conversion/README.md
examples/model-conversion/logits.cpp
examples/model-conversion/requirements.txt
examples/model-conversion/scripts/embedding/convert-model.sh
examples/model-conversion/scripts/embedding/run-converted-model.sh
examples/model-conversion/scripts/embedding/run-original-model.py
examples/model-conversion/scripts/utils/semantic_check.py
requirements/requirements-all.txt

index f0867cfe46c3a52d1e42e918b94405342ff3b48d..25b0514b29bc5705ed75040aa6c1030b0a2560da 100644 (file)
@@ -116,20 +116,39 @@ embedding-convert-model:
        METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
        ./scripts/embedding/convert-model.sh
 
+embedding-convert-model-st:
+       $(call validate_embedding_model_path,embedding-convert-model-st)
+       @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+       METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+       ./scripts/embedding/convert-model.sh -st
+
 embedding-run-original-model:
        $(call validate_embedding_model_path,embedding-run-original-model)
        @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+       USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
        ./scripts/embedding/run-original-model.py \
-       $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+       $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+       $(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
+
+embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
+embedding-run-original-model-st: embedding-run-original-model
 
 embedding-run-converted-model:
        @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
-       $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+       $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+       $(if $(USE_POOLING),--pooling)
+
+embedding-run-converted-model-st: USE_POOLING=1
+embedding-run-converted-model-st: embedding-run-converted-model
 
 embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
        @./scripts/embedding/compare-embeddings-logits.sh \
        $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
+embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
+       @./scripts/embedding/compare-embeddings-logits.sh \
+       $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+
 embedding-inspect-original-model:
        $(call validate_embedding_model_path,embedding-inspect-original-model)
        @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
index e95e05cd377cccb69f8f4f2b90a30fad9a6a7248..05d95d588bae7b30f481bc88326a394571678c8f 100644 (file)
@@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary
 file containing logits which will be used for comparison with the converted
 model, and the other is a text file which allows for manual visual inspection.
 
+#### Using SentenceTransformer with numbered layers
+For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
+03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
+
+```console
+# Run original model with SentenceTransformer (applies all numbered layers)
+(venv) $ make embedding-run-original-model-st
+
+# Run converted model with pooling enabled
+(venv) $ make embedding-run-converted-model-st
+```
+
+This will use the SentenceTransformer library to load and run the model, which
+automatically applies all the numbered layers in the correct order. This is
+particularly useful when comparing with models that should include these
+additional transformation layers beyond just the base model output.
+
 ### Model conversion
 After updates have been made to [gguf-py](../../gguf-py) to add support for the
 new model the model can be converted to GGUF format using the following command:
@@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
 (venv) $ make embedding-verify-logits
 ```
 
+For models with SentenceTransformer layers, use the `-st` verification target:
+```console
+(venv) $ make embedding-verify-logits-st
+```
+This convenience target automatically runs both the original model with SentenceTransformer
+and the converted model with pooling enabled, then compares the results.
+
 ### llama-server verification
 To verify that the converted model works with llama-server, the following
 command can be used:
index 6dc334189f4be6c81fc716c1a956794ae2c974f0..bbd095e6034cc358839598fc3c2b1dbf4b06e5f2 100644 (file)
@@ -1,4 +1,7 @@
 #include "llama.h"
+#include "common.h"
+
+
 #include <cstdio>
 #include <cstring>
 #include <string>
@@ -8,7 +11,10 @@
 
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
-    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
+    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
+    printf("\n");
+    printf("  -embd-norm: normalization type for pooled embeddings (default: 2)\n");
+    printf("              -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
     printf("\n");
 }
 
@@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
     std::string prompt = "Hello, my name is";
     int ngl = 0;
     bool embedding_mode = false;
+    bool pooling_enabled = false;
+    int32_t embd_norm = 2;  // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
 
     {
         int i = 1;
@@ -41,9 +49,13 @@ int main(int argc, char ** argv) {
                     return 1;
                 }
             } else if (strcmp(argv[i], "-embd-mode") == 0) {
+                embedding_mode = true;
+            } else if (strcmp(argv[i], "-pooling") == 0) {
+                pooling_enabled = true;
+            } else if (strcmp(argv[i], "-embd-norm") == 0) {
                 if (i + 1 < argc) {
                     try {
-                        embedding_mode = true;
+                        embd_norm = std::stoi(argv[++i]);
                     } catch (...) {
                         print_usage(argc, argv);
                         return 1;
@@ -112,7 +124,7 @@ int main(int argc, char ** argv) {
     ctx_params.no_perf = false;
     if (embedding_mode) {
         ctx_params.embeddings = true;
-        ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
         ctx_params.n_ubatch = ctx_params.n_batch;
     }
 
@@ -143,17 +155,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    float * logits;
-    int n_logits;
+    float * data_ptr;
+    int data_size;
     const char * type;
+    std::vector<float> embd_out;
 
     if (embedding_mode) {
-        logits = llama_get_embeddings(ctx);
-        n_logits = llama_model_n_embd(model) * batch.n_tokens;
+        const int n_embd = llama_model_n_embd(model);
+        const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
+        const int n_embeddings = n_embd * n_embd_count;
+        float * embeddings;
         type = "-embeddings";
 
-        const int n_embd = llama_model_n_embd(model);
-        const int n_embd_count = batch.n_tokens;
+        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+            embeddings = llama_get_embeddings_seq(ctx, 0);
+            embd_out.resize(n_embeddings);
+            printf("Normalizing embeddings using norm: %d\n", embd_norm);
+            common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
+            embeddings = embd_out.data();
+        } else {
+            embeddings = llama_get_embeddings(ctx);
+        }
 
         printf("Embedding dimension: %d\n", n_embd);
         printf("\n");
@@ -164,7 +186,7 @@ int main(int argc, char ** argv) {
 
             // Print first 3 values
             for (int i = 0; i < 3 && i < n_embd; i++) {
-                printf("%9.6f ", logits[j * n_embd + i]);
+                printf("%9.6f ", embeddings[j * n_embd + i]);
             }
 
             printf(" ... ");
@@ -172,7 +194,7 @@ int main(int argc, char ** argv) {
             // Print last 3 values
             for (int i = n_embd - 3; i < n_embd; i++) {
                 if (i >= 0) {
-                    printf("%9.6f ", logits[j * n_embd + i]);
+                    printf("%9.6f ", embeddings[j * n_embd + i]);
                 }
             }
 
@@ -180,27 +202,33 @@ int main(int argc, char ** argv) {
         }
         printf("\n");
 
-        printf("Embeddings size: %d\n", n_logits);
+        printf("Embeddings size: %d\n", n_embeddings);
+
+        data_ptr = embeddings;
+        data_size = n_embeddings;
     } else {
-        logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-        n_logits = llama_vocab_n_tokens(vocab);
+        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        const int n_logits = llama_vocab_n_tokens(vocab);
         type = "";
         printf("Vocab size: %d\n", n_logits);
+
+        data_ptr = logits;
+        data_size = n_logits;
     }
 
     std::filesystem::create_directory("data");
 
-    // Save logits to binary file
+    // Save data to binary file
     char bin_filename[512];
     snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
-    printf("Saving logits to %s\n", bin_filename);
+    printf("Saving data to %s\n", bin_filename);
 
     FILE * f = fopen(bin_filename, "wb");
     if (f == NULL) {
         fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
         return 1;
     }
-    fwrite(logits, sizeof(float), n_logits, f);
+    fwrite(data_ptr, sizeof(float), data_size, f);
     fclose(f);
 
     // Also save as text for debugging
@@ -211,27 +239,27 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
         return 1;
     }
-    for (int i = 0; i < n_logits; i++) {
-        fprintf(f, "%d: %.6f\n", i, logits[i]);
+    for (int i = 0; i < data_size; i++) {
+        fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
     }
     fclose(f);
 
     if (!embedding_mode) {
         printf("First 10 logits: ");
-        for (int i = 0; i < 10 && i < n_logits; i++) {
-            printf("%.6f ", logits[i]);
+        for (int i = 0; i < 10 && i < data_size; i++) {
+            printf("%.6f ", data_ptr[i]);
         }
         printf("\n");
 
         printf("Last 10 logits: ");
-        for (int i = n_logits - 10; i < n_logits; i++) {
-            if (i >= 0) printf("%.6f ", logits[i]);
+        for (int i = data_size - 10; i < data_size; i++) {
+            if (i >= 0) printf("%.6f ", data_ptr[i]);
         }
         printf("\n\n");
     }
 
-    printf("Logits saved to %s\n", bin_filename);
-    printf("Logits saved to %s\n", txt_filename);
+    printf("Data saved to %s\n", bin_filename);
+    printf("Data saved to %s\n", txt_filename);
 
     llama_free(ctx);
     llama_model_free(model);
index ac9f69e10bcc9762111b980fc4dd8d5ba12fbd13..229b2ec75b75b8fb0f564fa7b8c85fdae6ff8bd0 100644 (file)
@@ -4,3 +4,4 @@ torchvision
 transformers
 huggingface-hub
 accelerate
+sentence-transformers
index 0929e42413e6723eb827204dc2fd770e2a3c4075..9926350c072b243b9ac1f935eef28e497efa7cf6 100755 (executable)
@@ -2,6 +2,21 @@
 
 set -e
 
+# Parse command line arguments
+SENTENCE_TRANSFORMERS=""
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -st|--sentence-transformers)
+            SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
+            shift
+            ;;
+        *)
+            echo "Unknown option: $1"
+            exit 1
+            ;;
+    esac
+done
+
 MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
 OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
 TYPE="${OUTTYPE:-f16}"
@@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}"
 python ../../convert_hf_to_gguf.py --verbose \
     ${EMBEDDING_MODEL_PATH} \
     --outfile ${CONVERTED_MODEL} \
-    --outtype ${TYPE}
+    --outtype ${TYPE} \
+    ${SENTENCE_TRANSFORMERS}
 
 echo ""
 echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:"
index f3e26766320700fecec97bc619ffe2c8d5db25ec..0f490e6c3b20aac378907f65ec4783eb995697a2 100755 (executable)
@@ -5,6 +5,7 @@ set -e
 # Parse command line arguments
 CONVERTED_MODEL=""
 PROMPTS_FILE=""
+USE_POOLING=""
 
 while [[ $# -gt 0 ]]; do
     case $1 in
@@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do
             PROMPTS_FILE="$2"
             shift 2
             ;;
+        --pooling)
+            USE_POOLING="1"
+            shift
+            ;;
         *)
             if [ -z "$CONVERTED_MODEL" ]; then
                 CONVERTED_MODEL="$1"
@@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
 
 cmake --build ../../build --target llama-logits -j8
 # TODO: update logits.cpp to accept a --file/-f option for the prompt
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+if [ -n "$USE_POOLING" ]; then
+    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
+else
+    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+fi
index 4a3e162413fa67cee1dbd54f1d708ab52a2c2277..640e200a97dc38157cc0b1f23b516fe6dfefd5f2 100755 (executable)
@@ -14,6 +14,8 @@ unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
 parser = argparse.ArgumentParser(description='Process model with specified path')
 parser.add_argument('--model-path', '-m', help='Path to the model')
 parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
+parser.add_argument('--use-sentence-transformers', action='store_true',
+                    help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
 args = parser.parse_args()
 
 def read_prompt_from_file(file_path):
@@ -31,41 +33,52 @@ model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
 if model_path is None:
     parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
 
-tokenizer = AutoTokenizer.from_pretrained(model_path)
+# Determine if we should use SentenceTransformer
+use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
 
-config = AutoConfig.from_pretrained(model_path)
-
-# This can be used to override the sliding window size for manual testing. This
-# can be useful to verify the sliding window attention mask in the original model
-# and compare it with the converted .gguf model.
-if hasattr(config, 'sliding_window'):
-    original_sliding_window = config.sliding_window
-    #original_sliding_window = 6
-    print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
-
-print(f"Using unreleased model: {unreleased_model_name}")
-if unreleased_model_name:
-    model_name_lower = unreleased_model_name.lower()
-    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
-    class_name = f"{unreleased_model_name}Model"
-    print(f"Importing unreleased model module: {unreleased_module_path}")
-
-    try:
-        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
-        model = model_class.from_pretrained(model_path, config=config)
-    except (ImportError, AttributeError) as e:
-        print(f"Failed to import or load model: {e}")
-        exit(1)
+if use_sentence_transformers:
+    from sentence_transformers import SentenceTransformer
+    print("Using SentenceTransformer to apply all numbered layers")
+    model = SentenceTransformer(model_path)
+    tokenizer = model.tokenizer
+    config = model[0].auto_model.config  # type: ignore
 else:
-    model = AutoModel.from_pretrained(model_path, config=config)
-print(f"Model class: {type(model)}")
-print(f"Model file: {type(model).__module__}")
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    config = AutoConfig.from_pretrained(model_path)
+
+    # This can be used to override the sliding window size for manual testing. This
+    # can be useful to verify the sliding window attention mask in the original model
+    # and compare it with the converted .gguf model.
+    if hasattr(config, 'sliding_window'):
+        original_sliding_window = config.sliding_window
+        #original_sliding_window = 6
+        print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
+
+    print(f"Using unreleased model: {unreleased_model_name}")
+    if unreleased_model_name:
+        model_name_lower = unreleased_model_name.lower()
+        unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+        class_name = f"{unreleased_model_name}Model"
+        print(f"Importing unreleased model module: {unreleased_module_path}")
+
+        try:
+            model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
+            model = model_class.from_pretrained(model_path, config=config)
+        except (ImportError, AttributeError) as e:
+            print(f"Failed to import or load model: {e}")
+            exit(1)
+    else:
+        model = AutoModel.from_pretrained(model_path, config=config)
+    print(f"Model class: {type(model)}")
+    print(f"Model file: {type(model).__module__}")
 
 # Verify the model is using the correct sliding window
-if hasattr(model.config, 'sliding_window'):
-    print(f"Model's sliding_window: {model.config.sliding_window}")
-else:
-    print("Model config does not have sliding_window attribute")
+if not use_sentence_transformers:
+    if hasattr(model.config, 'sliding_window'):  # type: ignore
+        print(f"Model's sliding_window: {model.config.sliding_window}")  # type: ignore
+    else:
+        print("Model config does not have sliding_window attribute")
 
 model_name = os.path.basename(model_path)
 
@@ -75,34 +88,56 @@ if args.prompts_file:
 else:
     texts = ["Hello world today"]
 
-encoded = tokenizer(
-    texts,
-    padding=True,
-    truncation=True,
-    return_tensors="pt"
-)
-
-tokens = encoded['input_ids'][0]
-token_strings = tokenizer.convert_ids_to_tokens(tokens)
-for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
-    print(f"{token_id:6d} -> '{token_str}'")
-
 with torch.no_grad():
-    outputs = model(**encoded)
-    hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]
-
-    # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
-    all_embeddings = hidden_states[0].cpu().numpy()  # Shape: [seq_len, hidden_size]
-
-    print(f"Hidden states shape: {hidden_states.shape}")
-    print(f"All embeddings shape: {all_embeddings.shape}")
-    print(f"Embedding dimension: {all_embeddings.shape[1]}")
-
-    # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
-    n_embd = all_embeddings.shape[1]
-    n_embd_count = all_embeddings.shape[0]
-
-    print()  # Empty line to match C++ output
+    if use_sentence_transformers:
+        embeddings = model.encode(texts, convert_to_numpy=True)
+        all_embeddings = embeddings  # Shape: [batch_size, hidden_size]
+
+        encoded = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        )
+        tokens = encoded['input_ids'][0]
+        token_strings = tokenizer.convert_ids_to_tokens(tokens)
+        for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+            print(f"{token_id:6d} -> '{token_str}'")
+
+        print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
+        print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")  # type: ignore
+    else:
+        # Standard approach: use base model output only
+        encoded = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        )
+
+        tokens = encoded['input_ids'][0]
+        token_strings = tokenizer.convert_ids_to_tokens(tokens)
+        for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+            print(f"{token_id:6d} -> '{token_str}'")
+
+        outputs = model(**encoded)
+        hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]
+
+        all_embeddings = hidden_states[0].cpu().numpy()  # Shape: [seq_len, hidden_size]
+
+        print(f"Hidden states shape: {hidden_states.shape}")
+        print(f"All embeddings shape: {all_embeddings.shape}")
+        print(f"Embedding dimension: {all_embeddings.shape[1]}")
+
+    if len(all_embeddings.shape) == 1:
+        n_embd = all_embeddings.shape[0]  # type: ignore
+        n_embd_count = 1
+        all_embeddings = all_embeddings.reshape(1, -1)
+    else:
+        n_embd = all_embeddings.shape[1]  # type: ignore
+        n_embd_count = all_embeddings.shape[0]  # type: ignore
+
+    print()
 
     for j in range(n_embd_count):
         embedding = all_embeddings[j]
@@ -120,29 +155,23 @@ with torch.no_grad():
 
         print()  # New line
 
-    print()  # Final empty line to match C++ output
+    print()
 
     data_dir = Path("data")
     data_dir.mkdir(exist_ok=True)
     bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
     txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
 
-    # Save all embeddings flattened (matching what embedding.cpp would save if it did)
     flattened_embeddings = all_embeddings.flatten()
     flattened_embeddings.astype(np.float32).tofile(bin_filename)
 
     with open(txt_filename, "w") as f:
-        f.write(f"# Model class: {model_name}\n")
-        f.write(f"# Tokens: {token_strings}\n")
-        f.write(f"# Shape: {all_embeddings.shape}\n")
-        f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
-
+        idx = 0
         for j in range(n_embd_count):
-            f.write(f"# Token {j} ({token_strings[j]}):\n")
-            for i, value in enumerate(all_embeddings[j]):
-                f.write(f"{j}_{i}: {value:.6f}\n")
-            f.write("\n")
-    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
+            for value in all_embeddings[j]:
+                f.write(f"{idx}: {value:.6f}\n")
+                idx += 1
+    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
     print("")
     print(f"Saved bin embeddings to: {bin_filename}")
     print(f"Saved txt embeddings to: {txt_filename}")
index 7fd417bceaa8b4423774af4077be49e51c0d1586..2ac8b6b7b42cb0e42f60f3c74a4a3dd6714e947f 100644 (file)
@@ -35,7 +35,11 @@ def cosine_similarity(a, b=None):
 
 def load_embeddings_from_file(filename, n_tokens, n_embd):
     embeddings = np.fromfile(filename, dtype=np.float32)
-    return embeddings.reshape(n_tokens, n_embd)
+    # Check if this is pooled (single embedding) or per-token embeddings
+    if len(embeddings) == n_embd:
+        return embeddings.reshape(1, n_embd)
+    else:
+        return embeddings.reshape(n_tokens, n_embd)
 
 def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
     np.set_printoptions(suppress=True, precision=6)
@@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
     print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
 
     n_tokens = len(tokens)
+    is_pooled = python_emb.shape[0] == 1
+
+    if is_pooled:
+        print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
 
-    # 1. Direct embedding comparison
-    print(f"\n1. Raw Embedding Magnitude Comparison:")
-    # Check if the distance of each token embedding from the origin and compare
-    # if the vectors are on the same "sphere". This does not tell us about
-    # direction (meaning of the token embedding), just magnitude.
-    for i in range(n_tokens):
-        py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
-        cpp_mag = np.linalg.norm(cpp_emb[i])   # calculate standard euclidean norm for llama.cpp embeddings
+        # 1. Direct embedding comparison for pooled embeddings
+        print(f"\n1. Raw Embedding Magnitude Comparison:")
+        py_mag = np.linalg.norm(python_emb[0])
+        cpp_mag = np.linalg.norm(cpp_emb[0])
         ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
-        print(f"   Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
-
-    # 2. Cosine similarity between tokens within each model
-    # Here we check the direction of token embeddings to see if the have the
-    # same meaning (similarity). This is done by calculating cosine similarity
-    # of a pair of token embeddings within each model.
-    print(f"\n2. Within-Model Token Similarities:")
-    print("   Python model:")
-    for i in range(n_tokens):
-        for j in range(i+1, n_tokens):
-            sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
-            print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
-    print("   llama.cpp model:")
-    for i in range(n_tokens):
-        for j in range(i+1, n_tokens):
-            sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
-            print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
-    # 3. Cross-model similarity (same token position)
-    print(f"\n3. Cross-Model Same-Token Similarities:")
-    for i in range(n_tokens):
-        sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
-        print(f"   Token {i} ({tokens[i]}): {sim:.4f}")
-
-    # 4. Similarity matrix comparison
-    print(f"\n4. Similarity Matrix Differences:")
-    py_sim_matrix = cosine_similarity(python_emb)
-    cpp_sim_matrix = cosine_similarity(cpp_emb)
-    diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
-
-    print(f"   Max difference: {np.max(diff_matrix):.4f}")
-    print(f"   Mean difference: {np.mean(diff_matrix):.4f}")
-    print(f"   RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
-
-    return {
-        'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
-        'similarity_matrix_diff': diff_matrix,
-        'max_diff': np.max(diff_matrix),
-        'mean_diff': np.mean(diff_matrix),
-        'rms_diff': np.sqrt(np.mean(diff_matrix**2))
-    }
+        print(f"   Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+        # 2. Cross-model similarity for pooled embeddings
+        print(f"\n2. Cross-Model Pooled Embedding Similarity:")
+        sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
+        print(f"   Cosine similarity: {sim:.6f}")
+
+        return {
+            'cross_model_similarities': [sim],
+            'similarity_matrix_diff': np.array([[0.0]]),
+            'max_diff': 0.0,
+            'mean_diff': 0.0,
+            'rms_diff': 0.0
+        }
+    else:
+        # Original per-token comparison logic
+        # 1. Direct embedding comparison
+        print(f"\n1. Raw Embedding Magnitude Comparison:")
+        # Check if the distance of each token embedding from the origin and compare
+        # if the vectors are on the same "sphere". This does not tell us about
+        # direction (meaning of the token embedding), just magnitude.
+        for i in range(n_tokens):
+            py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
+            cpp_mag = np.linalg.norm(cpp_emb[i])   # calculate standard euclidean norm for llama.cpp embeddings
+            ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
+            print(f"   Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+        # 2. Cosine similarity between tokens within each model
+        # Here we check the direction of token embeddings to see if the have the
+        # same meaning (similarity). This is done by calculating cosine similarity
+        # of a pair of token embeddings within each model.
+        print(f"\n2. Within-Model Token Similarities:")
+        print("   Python model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
+                print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        print("   llama.cpp model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
+                print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        # 3. Cross-model similarity (same token position)
+        print(f"\n3. Cross-Model Same-Token Similarities:")
+        for i in range(n_tokens):
+            sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
+            print(f"   Token {i} ({tokens[i]}): {sim:.4f}")
+
+        # 4. Similarity matrix comparison
+        print(f"\n4. Similarity Matrix Differences:")
+        py_sim_matrix = cosine_similarity(python_emb)
+        cpp_sim_matrix = cosine_similarity(cpp_emb)
+        diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
+
+        print(f"   Max difference: {np.max(diff_matrix):.4f}")
+        print(f"   Mean difference: {np.mean(diff_matrix):.4f}")
+        print(f"   RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
+
+        return {
+            'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
+            'similarity_matrix_diff': diff_matrix,
+            'max_diff': np.max(diff_matrix),
+            'mean_diff': np.mean(diff_matrix),
+            'rms_diff': np.sqrt(np.mean(diff_matrix**2))
+        }
 
 def read_prompt_from_file(file_path):
     try:
index 56b6752ac0645600b3a20aa97b91e85457399511..6c6bea9490b4b0c7b390ecaa2640cfbb69ae7131 100644 (file)
@@ -14,3 +14,5 @@
 -r ./requirements-tool_bench.txt
 
 -r ./requirements-gguf_editor_gui.txt
+
+-r ../examples/model-conversion/requirements.txt