METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/embedding/convert-model.sh
+embedding-convert-model-st:
+ $(call validate_embedding_model_path,embedding-convert-model-st)
+ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+ ./scripts/embedding/convert-model.sh -st
+
embedding-run-original-model:
$(call validate_embedding_model_path,embedding-run-original-model)
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+ USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
./scripts/embedding/run-original-model.py \
- $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+ $(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
+
+embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
+embedding-run-original-model-st: embedding-run-original-model
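+# Note: the *-st targets rely on GNU Make target-specific variables: for
+# example USE_SENTENCE_TRANSFORMERS=1 is set only for the -st target and is
+# inherited by its prerequisite, so the plain recipe above runs with the flag
+# enabled.
+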
embedding-run-converted-model:
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
- $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+ $(if $(USE_POOLING),--pooling)
+
+embedding-run-converted-model-st: USE_POOLING=1
+embedding-run-converted-model-st: embedding-run-converted-model
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
+ @./scripts/embedding/compare-embeddings-logits.sh \
+ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+
embedding-inspect-original-model:
$(call validate_embedding_model_path,embedding-inspect-original-model)
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
file containing logits which will be used for comparison with the converted
model, and the other is a text file which allows for manual visual inspection.
+#### Using SentenceTransformer with numbered layers
+For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
+03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
+
+```console
+# Run original model with SentenceTransformer (applies all numbered layers)
+(venv) $ make embedding-run-original-model-st
+
+# Run converted model with pooling enabled
+(venv) $ make embedding-run-converted-model-st
+```
+
+These targets use the SentenceTransformer library to load and run the model,
+which automatically applies all of the numbered layers in the correct order.
+This is particularly useful when the reference embeddings are expected to
+include these additional transformation layers on top of the base model output.
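+
+Conceptually, the `-st` run of the original model corresponds to something like
+the following sketch (the model path below is only a placeholder):
+
+```python
+from sentence_transformers import SentenceTransformer
+
+# SentenceTransformer loads the base transformer together with the numbered
+# modules (01_Pooling, 02_Dense, 03_Dense, 04_Normalize) and applies them in
+# order, producing one pooled embedding per input sentence.
+model = SentenceTransformer("path/to/embedding-model")  # placeholder path
+embeddings = model.encode(["Hello world today"], convert_to_numpy=True)
+print(embeddings.shape)  # (1, embedding_dim)
+```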
+
### Model conversion
After updates have been made to [gguf-py](../../gguf-py) to add support for the
new model, the model can be converted to GGUF format using the following command:
(venv) $ make embedding-verify-logits
```
+For models with SentenceTransformer layers, use the `-st` verification target:
+```console
+(venv) $ make embedding-verify-logits-st
+```
+This convenience target automatically runs both the original model with SentenceTransformer
+and the converted model with pooling enabled, then compares the results.
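+
+The `-st` targets also honor the `PROMPTS_FILE` variable, so a custom prompts
+file can be used for the comparison (the path below is only an example):
+```console
+(venv) $ make embedding-verify-logits-st PROMPTS_FILE=prompts.txt
+```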
+
### llama-server verification
To verify that the converted model works with llama-server, the following
command can be used:
#include "llama.h"
+#include "common.h"
+
#include <cstdio>
#include <cstring>
#include <string>
static void print_usage(int, char ** argv) {
printf("\nexample usage:\n");
- printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
+ printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
+ printf("\n");
+ printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
+ printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
printf("\n");
}
std::string prompt = "Hello, my name is";
int ngl = 0;
bool embedding_mode = false;
+ bool pooling_enabled = false;
+ int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
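+    // With the default value of 2 the pooled embedding is divided by its
+    // Euclidean length (sqrt of the sum of squared components); 1 divides by
+    // the sum of absolute values, and values > 2 apply the corresponding p-norm.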
{
int i = 1;
return 1;
}
} else if (strcmp(argv[i], "-embd-mode") == 0) {
+ embedding_mode = true;
+ } else if (strcmp(argv[i], "-pooling") == 0) {
+ pooling_enabled = true;
+ } else if (strcmp(argv[i], "-embd-norm") == 0) {
if (i + 1 < argc) {
try {
- embedding_mode = true;
+ embd_norm = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
- ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
+ ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
return 1;
}
- float * logits;
- int n_logits;
+ float * data_ptr;
+ int data_size;
const char * type;
+ std::vector<float> embd_out;
if (embedding_mode) {
- logits = llama_get_embeddings(ctx);
- n_logits = llama_model_n_embd(model) * batch.n_tokens;
+ const int n_embd = llama_model_n_embd(model);
+ const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
+ const int n_embeddings = n_embd * n_embd_count;
+ float * embeddings;
type = "-embeddings";
- const int n_embd = llama_model_n_embd(model);
- const int n_embd_count = batch.n_tokens;
+ if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+ embeddings = llama_get_embeddings_seq(ctx, 0);
+ embd_out.resize(n_embeddings);
+ printf("Normalizing embeddings using norm: %d\n", embd_norm);
+ common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
+ embeddings = embd_out.data();
+ } else {
+ embeddings = llama_get_embeddings(ctx);
+ }
printf("Embedding dimension: %d\n", n_embd);
printf("\n");
// Print first 3 values
for (int i = 0; i < 3 && i < n_embd; i++) {
- printf("%9.6f ", logits[j * n_embd + i]);
+ printf("%9.6f ", embeddings[j * n_embd + i]);
}
printf(" ... ");
// Print last 3 values
for (int i = n_embd - 3; i < n_embd; i++) {
if (i >= 0) {
- printf("%9.6f ", logits[j * n_embd + i]);
+ printf("%9.6f ", embeddings[j * n_embd + i]);
}
}
}
printf("\n");
- printf("Embeddings size: %d\n", n_logits);
+ printf("Embeddings size: %d\n", n_embeddings);
+
+ data_ptr = embeddings;
+ data_size = n_embeddings;
} else {
- logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
- n_logits = llama_vocab_n_tokens(vocab);
+ float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+ const int n_logits = llama_vocab_n_tokens(vocab);
type = "";
printf("Vocab size: %d\n", n_logits);
+
+ data_ptr = logits;
+ data_size = n_logits;
}
std::filesystem::create_directory("data");
- // Save logits to binary file
+ // Save data to binary file
char bin_filename[512];
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
- printf("Saving logits to %s\n", bin_filename);
+ printf("Saving data to %s\n", bin_filename);
FILE * f = fopen(bin_filename, "wb");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
return 1;
}
- fwrite(logits, sizeof(float), n_logits, f);
+ fwrite(data_ptr, sizeof(float), data_size, f);
fclose(f);
// Also save as text for debugging
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
return 1;
}
- for (int i = 0; i < n_logits; i++) {
- fprintf(f, "%d: %.6f\n", i, logits[i]);
+ for (int i = 0; i < data_size; i++) {
+ fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
}
fclose(f);
if (!embedding_mode) {
printf("First 10 logits: ");
- for (int i = 0; i < 10 && i < n_logits; i++) {
- printf("%.6f ", logits[i]);
+ for (int i = 0; i < 10 && i < data_size; i++) {
+ printf("%.6f ", data_ptr[i]);
}
printf("\n");
printf("Last 10 logits: ");
- for (int i = n_logits - 10; i < n_logits; i++) {
- if (i >= 0) printf("%.6f ", logits[i]);
+ for (int i = data_size - 10; i < data_size; i++) {
+ if (i >= 0) printf("%.6f ", data_ptr[i]);
}
printf("\n\n");
}
- printf("Logits saved to %s\n", bin_filename);
- printf("Logits saved to %s\n", txt_filename);
+ printf("Data saved to %s\n", bin_filename);
+ printf("Data saved to %s\n", txt_filename);
llama_free(ctx);
llama_model_free(model);
transformers
huggingface-hub
accelerate
+sentence-transformers
set -e
+# Parse command line arguments
+SENTENCE_TRANSFORMERS=""
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ -st|--sentence-transformers)
+ SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
+ shift
+ ;;
+ *)
+ echo "Unknown option: $1"
+ exit 1
+ ;;
+ esac
+done
+
MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
TYPE="${OUTTYPE:-f16}"
python ../../convert_hf_to_gguf.py --verbose \
${EMBEDDING_MODEL_PATH} \
--outfile ${CONVERTED_MODEL} \
- --outtype ${TYPE}
+ --outtype ${TYPE} \
+ ${SENTENCE_TRANSFORMERS}
echo ""
echo "The environment variable CONVERTED_EMBEDDING_MODEL can be set to this path using:"
# Parse command line arguments
CONVERTED_MODEL=""
PROMPTS_FILE=""
+USE_POOLING=""
while [[ $# -gt 0 ]]; do
case $1 in
PROMPTS_FILE="$2"
shift 2
;;
+ --pooling)
+ USE_POOLING="1"
+ shift
+ ;;
*)
if [ -z "$CONVERTED_MODEL" ]; then
CONVERTED_MODEL="$1"
cmake --build ../../build --target llama-logits -j8
# TODO: update logits.cpp to accept a --file/-f option for the prompt
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+if [ -n "$USE_POOLING" ]; then
+ ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
+else
+ ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+fi
parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
+parser.add_argument('--use-sentence-transformers', action='store_true',
+ help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
args = parser.parse_args()
def read_prompt_from_file(file_path):
if model_path is None:
parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
-tokenizer = AutoTokenizer.from_pretrained(model_path)
+# Determine if we should use SentenceTransformer
+use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
-config = AutoConfig.from_pretrained(model_path)
-
-# This can be used to override the sliding window size for manual testing. This
-# can be useful to verify the sliding window attention mask in the original model
-# and compare it with the converted .gguf model.
-if hasattr(config, 'sliding_window'):
- original_sliding_window = config.sliding_window
- #original_sliding_window = 6
- print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
-
-print(f"Using unreleased model: {unreleased_model_name}")
-if unreleased_model_name:
- model_name_lower = unreleased_model_name.lower()
- unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
- class_name = f"{unreleased_model_name}Model"
- print(f"Importing unreleased model module: {unreleased_module_path}")
-
- try:
- model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
- model = model_class.from_pretrained(model_path, config=config)
- except (ImportError, AttributeError) as e:
- print(f"Failed to import or load model: {e}")
- exit(1)
+if use_sentence_transformers:
+ from sentence_transformers import SentenceTransformer
+ print("Using SentenceTransformer to apply all numbered layers")
+ model = SentenceTransformer(model_path)
+ tokenizer = model.tokenizer
+ config = model[0].auto_model.config # type: ignore
else:
- model = AutoModel.from_pretrained(model_path, config=config)
-print(f"Model class: {type(model)}")
-print(f"Model file: {type(model).__module__}")
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ config = AutoConfig.from_pretrained(model_path)
+
+ # This can be used to override the sliding window size for manual testing. This
+ # can be useful to verify the sliding window attention mask in the original model
+ # and compare it with the converted .gguf model.
+ if hasattr(config, 'sliding_window'):
+ original_sliding_window = config.sliding_window
+ #original_sliding_window = 6
+ print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
+
+ print(f"Using unreleased model: {unreleased_model_name}")
+ if unreleased_model_name:
+ model_name_lower = unreleased_model_name.lower()
+ unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+ class_name = f"{unreleased_model_name}Model"
+ print(f"Importing unreleased model module: {unreleased_module_path}")
+
+ try:
+ model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
+ model = model_class.from_pretrained(model_path, config=config)
+ except (ImportError, AttributeError) as e:
+ print(f"Failed to import or load model: {e}")
+ exit(1)
+ else:
+ model = AutoModel.from_pretrained(model_path, config=config)
+ print(f"Model class: {type(model)}")
+ print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window
-if hasattr(model.config, 'sliding_window'):
- print(f"Model's sliding_window: {model.config.sliding_window}")
-else:
- print("Model config does not have sliding_window attribute")
+if not use_sentence_transformers:
+ if hasattr(model.config, 'sliding_window'): # type: ignore
+ print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
+ else:
+ print("Model config does not have sliding_window attribute")
model_name = os.path.basename(model_path)
else:
texts = ["Hello world today"]
-encoded = tokenizer(
- texts,
- padding=True,
- truncation=True,
- return_tensors="pt"
-)
-
-tokens = encoded['input_ids'][0]
-token_strings = tokenizer.convert_ids_to_tokens(tokens)
-for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
- print(f"{token_id:6d} -> '{token_str}'")
-
with torch.no_grad():
- outputs = model(**encoded)
- hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
-
- # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
- all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
-
- print(f"Hidden states shape: {hidden_states.shape}")
- print(f"All embeddings shape: {all_embeddings.shape}")
- print(f"Embedding dimension: {all_embeddings.shape[1]}")
-
- # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
- n_embd = all_embeddings.shape[1]
- n_embd_count = all_embeddings.shape[0]
-
- print() # Empty line to match C++ output
+ if use_sentence_transformers:
+ embeddings = model.encode(texts, convert_to_numpy=True)
+ all_embeddings = embeddings # Shape: [batch_size, hidden_size]
+
+ encoded = tokenizer(
+ texts,
+ padding=True,
+ truncation=True,
+ return_tensors="pt"
+ )
+ tokens = encoded['input_ids'][0]
+ token_strings = tokenizer.convert_ids_to_tokens(tokens)
+ for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+ print(f"{token_id:6d} -> '{token_str}'")
+
+ print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
+ print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
+ else:
+ # Standard approach: use base model output only
+ encoded = tokenizer(
+ texts,
+ padding=True,
+ truncation=True,
+ return_tensors="pt"
+ )
+
+ tokens = encoded['input_ids'][0]
+ token_strings = tokenizer.convert_ids_to_tokens(tokens)
+ for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+ print(f"{token_id:6d} -> '{token_str}'")
+
+ outputs = model(**encoded)
+ hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
+
+ all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
+
+ print(f"Hidden states shape: {hidden_states.shape}")
+ print(f"All embeddings shape: {all_embeddings.shape}")
+ print(f"Embedding dimension: {all_embeddings.shape[1]}")
+
+ if len(all_embeddings.shape) == 1:
+ n_embd = all_embeddings.shape[0] # type: ignore
+ n_embd_count = 1
+ all_embeddings = all_embeddings.reshape(1, -1)
+ else:
+ n_embd = all_embeddings.shape[1] # type: ignore
+ n_embd_count = all_embeddings.shape[0] # type: ignore
+
+ print()
for j in range(n_embd_count):
embedding = all_embeddings[j]
print() # New line
- print() # Final empty line to match C++ output
+ print()
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
- # Save all embeddings flattened (matching what embedding.cpp would save if it did)
flattened_embeddings = all_embeddings.flatten()
flattened_embeddings.astype(np.float32).tofile(bin_filename)
with open(txt_filename, "w") as f:
- f.write(f"# Model class: {model_name}\n")
- f.write(f"# Tokens: {token_strings}\n")
- f.write(f"# Shape: {all_embeddings.shape}\n")
- f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
-
+ idx = 0
for j in range(n_embd_count):
- f.write(f"# Token {j} ({token_strings[j]}):\n")
- for i, value in enumerate(all_embeddings[j]):
- f.write(f"{j}_{i}: {value:.6f}\n")
- f.write("\n")
- print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
+ for value in all_embeddings[j]:
+ f.write(f"{idx}: {value:.6f}\n")
+ idx += 1
+ print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
print("")
print(f"Saved bin embeddings to: {bin_filename}")
print(f"Saved txt embeddings to: {txt_filename}")
def load_embeddings_from_file(filename, n_tokens, n_embd):
embeddings = np.fromfile(filename, dtype=np.float32)
- return embeddings.reshape(n_tokens, n_embd)
+ # Check if this is pooled (single embedding) or per-token embeddings
+ if len(embeddings) == n_embd:
+ return embeddings.reshape(1, n_embd)
+ else:
+ return embeddings.reshape(n_tokens, n_embd)
def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
np.set_printoptions(suppress=True, precision=6)
print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
n_tokens = len(tokens)
+ is_pooled = python_emb.shape[0] == 1
+
+ if is_pooled:
+ print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
- # 1. Direct embedding comparison
- print(f"\n1. Raw Embedding Magnitude Comparison:")
- # Check if the distance of each token embedding from the origin and compare
- # if the vectors are on the same "sphere". This does not tell us about
- # direction (meaning of the token embedding), just magnitude.
- for i in range(n_tokens):
- py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
- cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
+ # 1. Direct embedding comparison for pooled embeddings
+ print(f"\n1. Raw Embedding Magnitude Comparison:")
+ py_mag = np.linalg.norm(python_emb[0])
+ cpp_mag = np.linalg.norm(cpp_emb[0])
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
- print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
-
- # 2. Cosine similarity between tokens within each model
- # Here we check the direction of token embeddings to see if the have the
- # same meaning (similarity). This is done by calculating cosine similarity
- # of a pair of token embeddings within each model.
- print(f"\n2. Within-Model Token Similarities:")
- print(" Python model:")
- for i in range(n_tokens):
- for j in range(i+1, n_tokens):
- sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
- print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
- print(" llama.cpp model:")
- for i in range(n_tokens):
- for j in range(i+1, n_tokens):
- sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
- print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
- # 3. Cross-model similarity (same token position)
- print(f"\n3. Cross-Model Same-Token Similarities:")
- for i in range(n_tokens):
- sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
- print(f" Token {i} ({tokens[i]}): {sim:.4f}")
-
- # 4. Similarity matrix comparison
- print(f"\n4. Similarity Matrix Differences:")
- py_sim_matrix = cosine_similarity(python_emb)
- cpp_sim_matrix = cosine_similarity(cpp_emb)
- diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
-
- print(f" Max difference: {np.max(diff_matrix):.4f}")
- print(f" Mean difference: {np.mean(diff_matrix):.4f}")
- print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
-
- return {
- 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
- 'similarity_matrix_diff': diff_matrix,
- 'max_diff': np.max(diff_matrix),
- 'mean_diff': np.mean(diff_matrix),
- 'rms_diff': np.sqrt(np.mean(diff_matrix**2))
- }
+ print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+ # 2. Cross-model similarity for pooled embeddings
+ print(f"\n2. Cross-Model Pooled Embedding Similarity:")
+ sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
+ print(f" Cosine similarity: {sim:.6f}")
+
+ return {
+ 'cross_model_similarities': [sim],
+ 'similarity_matrix_diff': np.array([[0.0]]),
+ 'max_diff': 0.0,
+ 'mean_diff': 0.0,
+ 'rms_diff': 0.0
+ }
+ else:
+ # Original per-token comparison logic
+ # 1. Direct embedding comparison
+ print(f"\n1. Raw Embedding Magnitude Comparison:")
+        # Check the distance of each token embedding from the origin and compare
+        # whether the vectors lie on the same "sphere". This does not tell us about
+        # direction (meaning of the token embedding), just magnitude.
+ for i in range(n_tokens):
+ py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
+ cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
+ ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
+ print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+ # 2. Cosine similarity between tokens within each model
+        # Here we check the direction of token embeddings to see if they have the
+ # same meaning (similarity). This is done by calculating cosine similarity
+ # of a pair of token embeddings within each model.
+ print(f"\n2. Within-Model Token Similarities:")
+ print(" Python model:")
+ for i in range(n_tokens):
+ for j in range(i+1, n_tokens):
+ sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
+ print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+ print(" llama.cpp model:")
+ for i in range(n_tokens):
+ for j in range(i+1, n_tokens):
+ sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
+ print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+ # 3. Cross-model similarity (same token position)
+ print(f"\n3. Cross-Model Same-Token Similarities:")
+ for i in range(n_tokens):
+ sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
+ print(f" Token {i} ({tokens[i]}): {sim:.4f}")
+
+ # 4. Similarity matrix comparison
+ print(f"\n4. Similarity Matrix Differences:")
+ py_sim_matrix = cosine_similarity(python_emb)
+ cpp_sim_matrix = cosine_similarity(cpp_emb)
+ diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
+
+ print(f" Max difference: {np.max(diff_matrix):.4f}")
+ print(f" Mean difference: {np.mean(diff_matrix):.4f}")
+ print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
+
+ return {
+ 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
+ 'similarity_matrix_diff': diff_matrix,
+ 'max_diff': np.max(diff_matrix),
+ 'mean_diff': np.mean(diff_matrix),
+ 'rms_diff': np.sqrt(np.mean(diff_matrix**2))
+ }
def read_prompt_from_file(file_path):
try:
-r ./requirements-tool_bench.txt
-r ./requirements-gguf_editor_gui.txt
+
+-r ../examples/model-conversion/requirements.txt