]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : add debug utility/example (#18464)
authorDaniel Bevenius <redacted>
Wed, 7 Jan 2026 09:42:19 +0000 (10:42 +0100)
committerGitHub <redacted>
Wed, 7 Jan 2026 09:42:19 +0000 (10:42 +0100)
* examples : add debug utility/example

This commit introduces a new example named llama-debug which is a
utility that is intended to be used to assist with developing/debugging
a converted model.

The motivation for this utilitiy is to assist in model conversion work
to verify that the model produces the expected outputs. It is intended
to replace logits.cpp in examples/model-conversion.

Example usage:
```console
./build/bin/llama-debug \
    -m models/Qwen2.5-0.5B-Instruct.gguf \
    --prompt "Hello, my name is" \
    --save-logits
...
Model add_bos: false
Input prompt: "Hello, my name is"
Token ids (5):
Hello(9707) ,(11)  my(847)  name(829)  is(374)
Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.bin
Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.txt
Prompt saved to data/llamacpp-Qwen2.5-0.5B-Instruct-prompt.txt
Tokens saved to data/llamacpp-Qwen2.5-0.5B-Instruct-tokens.bin
```

For more details about the options available for this example, please
refer to examples/debug/README.md.

* throw runtime error instead of logging error

* remove params.warmup and enable the warmup/nowarmup option

* model-conversion : remove logits.cpp

This commit removes logits.cpp in favor of using llama-debug for
generating logits and embeddings.

* examples : remove model-conversion directory

This was missed in the previous commit.

* model-conversion : add support for saving prompt and token ids

This commit add support for storing the prompt and the token ids for the
prompt when running the original models.

The motivation for this is that this will allow us to compare the prompt
and the tokens generated for the prompt when verifing the converted
model. Currently it is possible that even if the same prompt is used
that the tokens generated are different if there is a difference in the
tokenization between the original and converted model which would
currently go unnoticed (the verification will most likely fail but it
might not be obvious why).

* squash! model-conversion : add support for saving prompt and token ids

fix pyright errors.

* model-conversion : add compare_tokens utility

This commit adds a script to compare token outputs between original and
converted models.

Example usage:
```console
(venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16

Comparing tokens between:
  Original : pytorch-gemma-3-270m-it (6 tokens)
  Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens)

āœ… All 6 tokens match!
```
And there is a verbose flag that will also print out the prompts:
```console
(venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 -v

Original model prompt (pytorch-gemma-3-270m-it):
  prompt: Hello, my name is
n_tokens: 6
token ids: 2, 9259, 236764, 1041, 1463, 563

Converted model prompt (llamacpp-gemma-3-270m-it-bf16):
  prompt: Hello, my name is
n_tokens: 6
token ids: 2, 9259, 236764, 1041, 1463, 563

Comparing tokens between:
  Original : pytorch-gemma-3-270m-it (6 tokens)
  Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens)

āœ… All 6 tokens match!
```

* model-conversion : add token comparison to verifiction scripts

This commit add the calling of the compare_tokens function in
compare-logits.py and semantic_check.py to ensure that the token ids
that the tokenizers procoduce are the same before proceeding with
verifying the logits/embeddings.

Placing them in the existing scripts instead calling them separately
ensures that the token comparison is always done prior to the
logit/embedding verifications.

Follow up commit/pr could refactor the causal logits verification into
a single script instead of the two that exist now. This would reduce the
code and make it consistent with the embeddings verficiation which only
has a single script.

* debug : use llama_model_n_embd_out

This commit updates the debug example to use the new function
llama_model_n_embd_out instead of llama_model_n_embd.

The motivation for this change is to support late interation retriever
models, like LFM2-ColBert-350M, where the output embeddings are down
projected to a lower dimension.

* debug : add print_usage function

This commit adds a print_usage function that is passed to the
common_params_parse.

The motivation for this is that this enables a specific usage message
which will be printed after all the options, for example:
```console
example usage:

  Print tensors:

  ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --verbose

  The tensors to be printed can be filtered with --tensor-filter option.

  Save logits/embeddings:

  ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --save-logits

  Add --embedding to save embeddings
```

18 files changed:
common/arg.cpp
common/common.h
examples/CMakeLists.txt
examples/debug/CMakeLists.txt [new file with mode: 0644]
examples/debug/README.md [new file with mode: 0644]
examples/debug/debug.cpp [new file with mode: 0644]
examples/model-conversion/CMakeLists.txt [deleted file]
examples/model-conversion/logits.cpp [deleted file]
examples/model-conversion/scripts/causal/compare-logits.py
examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py
examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
examples/model-conversion/scripts/causal/run-converted-model.sh
examples/model-conversion/scripts/causal/run-org-model.py
examples/model-conversion/scripts/embedding/run-converted-model.sh
examples/model-conversion/scripts/embedding/run-original-model.py
examples/model-conversion/scripts/utils/common.py
examples/model-conversion/scripts/utils/compare_tokens.py [new file with mode: 0755]
examples/model-conversion/scripts/utils/semantic_check.py

index c3610d262b3ae3b36d1e72c839dea55803e76012..a67a26e2dc898a5cfce633a04299e1a22bdadaff 100644 (file)
@@ -1445,7 +1445,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, bool value) {
             params.warmup = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
@@ -1761,7 +1761,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));
     add_opt(common_arg(
         {"--attention"}, "{causal,non-causal}",
         "attention type for embeddings, use model default if unspecified",
@@ -2609,7 +2609,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
     add_opt(common_arg(
         {"--embd-output-format"}, "FORMAT",
         "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
@@ -2687,7 +2687,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.embedding = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--rerank", "--reranking"},
         string_format("enable reranking endpoint on server (default: %s)", "disabled"),
@@ -3378,6 +3378,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        {"--save-logits"},
+        string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
+        [](common_params & params) {
+            params.save_logits = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
+    add_opt(common_arg(
+        {"--logits-output-dir"}, "PATH",
+        string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.logits_output_dir = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
+    add_opt(common_arg(
+        {"--tensor-filter"}, "REGEX",
+        "filter tensor names for debug output (regex pattern, can be specified multiple times)",
+        [](common_params & params, const std::string & value) {
+            params.tensor_filter.push_back(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
 
     // presets
     add_opt(common_arg(
index daea6ded5ba537b65f871cea67c9630915a631bb..d6fd0d37a944bc11fe34ef81e646adb7a39967bc 100644 (file)
@@ -80,6 +80,7 @@ int32_t cpu_get_num_math();
 //
 
 enum llama_example {
+    LLAMA_EXAMPLE_DEBUG,
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
     LLAMA_EXAMPLE_COMPLETION,
@@ -372,6 +373,11 @@ struct common_params {
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
 
+    // llama-debug specific options
+    std::string logits_output_dir = "data"; // directory for saving logits output files                     // NOLINT
+    bool        save_logits       = false;  // whether to save logits to files                              // NOLINT
+    std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex)                 // NOLINT
+
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
index 91797cf78a929e333c9b796a2d7d06a39add82a5..a29dc707c3dc82d6d0207fe481235167dda7a489 100644 (file)
@@ -15,6 +15,7 @@ llama_add_compile_flags()
 if (EMSCRIPTEN)
 else()
     add_subdirectory(batched)
+    add_subdirectory(debug)
     add_subdirectory(embedding)
     add_subdirectory(eval-callback)
 
@@ -34,7 +35,6 @@ else()
     add_subdirectory(gen-docs)
     add_subdirectory(training)
     add_subdirectory(diffusion)
-    add_subdirectory(model-conversion)
     if (NOT GGML_BACKEND_DL)
         add_subdirectory(convert-llama2c-to-ggml)
         # these examples use the backends directly and cannot be built with dynamic loading
diff --git a/examples/debug/CMakeLists.txt b/examples/debug/CMakeLists.txt
new file mode 100644 (file)
index 0000000..3459307
--- /dev/null
@@ -0,0 +1,5 @@
+set(TARGET llama-debug)
+add_executable(${TARGET} debug.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/debug/README.md b/examples/debug/README.md
new file mode 100644 (file)
index 0000000..28e00c9
--- /dev/null
@@ -0,0 +1,54 @@
+# llama.cpp/examples/debug
+
+This is a utility intended to help debug a model by registering a callback that
+logs GGML operations and tensor data. It can also store the generated logits or
+embeddings as well as the prompt and token ids for comparision with the original
+model.
+
+### Usage
+
+```shell
+llama-debug \
+  --hf-repo ggml-org/models \
+  --hf-file phi-2/ggml-model-q4_0.gguf \
+  --model phi-2-q4_0.gguf \
+  --prompt hello \
+  --save-logits \
+  --verbose
+```
+The tensor data is logged as debug and required the --verbose flag. The reason
+for this is that while useful for a model with many layers there can be a lot of
+output. You can filter the tensor names using the `--tensor-filter` option.
+
+A recommended approach is to first run without `--verbose` and see if the
+generated logits/embeddings are close to the original model. If they are not,
+then it might be required to inspect tensor by tensor and in that case it is
+useful to enable the `--verbose` flag along with `--tensor-filter` to focus on
+specific tensors.
+
+### Options
+This example supports all standard `llama.cpp` options and also accepts the
+following options:
+```console
+$ llama-debug --help
+...
+
+----- example-specific params -----
+
+--save-logits                           save final logits to files for verification (default: false)
+--logits-output-dir PATH                directory for saving logits output files (default: data)
+--tensor-filter REGEX                   filter tensor names for debug output (regex pattern, can be specified multiple times)
+```
+
+### Output Files
+
+When `--save-logits` is enabled, the following files are created in the output
+directory:
+
+* `llamacpp-<model>[-embeddings].bin`        - Binary output (logits or embeddings)
+* `llamacpp-<model>[-embeddings].txt`        - Text output (logits or embeddings, one per line)
+* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
+* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
+
+These files can be compared against the original model's output to verify the
+converted model.
diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp
new file mode 100644 (file)
index 0000000..9bc5d0a
--- /dev/null
@@ -0,0 +1,421 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+#include "ggml.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <string>
+#include <vector>
+#include <filesystem>
+#include <fstream>
+#include <regex>
+
+static void print_usage(int, char ** argv) {
+    const std::string usage_template = R"(
+        example usage:
+
+          Print tensors:
+
+          {prog} -m model.gguf -p "Hello my name is" --verbose
+
+          The tensors to be printed can be filtered with --tensor-filter option.
+
+          Save logits/embeddings:
+
+          {prog} -m model.gguf -p "Hello my name is" --save-logits
+
+          Add --embedding to save embeddings)" "\n";
+
+    // Fix the source code indentation above that is introduced by the raw string literal.
+    std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
+    usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
+    LOG("%s\n", usage.c_str());
+}
+
+static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data);
+
+struct callback_data {
+    std::vector<uint8_t>    data;
+    std::vector<std::regex> tensor_filters;
+
+    callback_data() = default;
+
+    callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
+        for (const auto & pattern : filter_patterns) {
+            try {
+                std::string anchored_pattern = "^" + pattern;
+                tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
+            } catch (const std::regex_error & e) {
+                throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
+            }
+        }
+        params.cb_eval           = ggml_debug;
+        params.cb_eval_user_data = this;
+    }
+};
+
+struct output_data {
+    float *                  data_ptr    = nullptr;
+    int                      data_size   = 0;
+    std::string              type_suffix;
+    std::vector<float>       storage;
+    std::string              prompt;
+    std::vector<llama_token> tokens;
+
+    output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+        const bool add_bos = llama_vocab_get_add_bos(vocab);
+
+        tokens = common_tokenize(ctx, params.prompt, add_bos);
+        prompt = params.prompt;
+
+        if (params.embedding) {
+            const int  n_embd          = llama_model_n_embd_out(model);
+            const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
+            const int  n_embd_count    = pooling_enabled ? 1 : tokens.size();
+            const int  n_embeddings    = n_embd * n_embd_count;
+
+            float * embeddings;
+            if (pooling_enabled) {
+                embeddings = llama_get_embeddings_seq(ctx, 0);
+                storage.resize(n_embeddings);
+                common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
+                embeddings = storage.data();
+            } else {
+                embeddings = llama_get_embeddings(ctx);
+            }
+
+            data_ptr = embeddings;
+            data_size = n_embeddings;
+            type_suffix = "-embeddings";
+        } else {
+            const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
+            const int n_logits = llama_vocab_n_tokens(vocab);
+
+            data_ptr = const_cast<float*>(logits);
+            data_size = n_logits;
+            type_suffix = "";
+        }
+    }
+};
+
+static std::string ggml_ne_string(const ggml_tensor * t) {
+    std::string str;
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        str += std::to_string(t->ne[i]);
+        if (i + 1 < GGML_MAX_DIMS) {
+            str += ", ";
+        }
+    }
+    return str;
+}
+
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+static float ggml_get_float_value(const uint8_t * data, ggml_type type,
+        const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
+    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+    switch (type) {
+        case GGML_TYPE_F16:
+            return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
+        case GGML_TYPE_F32:
+            return *(const float *) &data[i];
+        case GGML_TYPE_I64:
+            return (float) *(const int64_t *) &data[i];
+        case GGML_TYPE_I32:
+            return (float) *(const int32_t *) &data[i];
+        case GGML_TYPE_I16:
+            return (float) *(const int16_t *) &data[i];
+        case GGML_TYPE_I8:
+            return (float) *(const int8_t *) &data[i];
+        case GGML_TYPE_BF16:
+            return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
+    GGML_ASSERT(n > 0);
+    float sum    = 0;
+    float sum_sq = 0.0;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+                    sum    += v;
+                    sum_sq += v * v;
+                }
+            }
+        }
+    }
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        LOG_DBG("                                     [\n");
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                LOG_DBG("                                      ..., \n");
+                i2 = ne[2] - n;
+            }
+            LOG_DBG("                                      [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    LOG_DBG("                                       ..., \n");
+                    i1 = ne[1] - n;
+                }
+                LOG_DBG("                                       [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        LOG_DBG("..., ");
+                        i0 = ne[0] - n;
+                    }
+                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+                    LOG_DBG("%12.4f", v);
+                    if (i0 < ne[0] - 1) {
+                        LOG_DBG(", ");
+                    }
+                }
+                LOG_DBG("],\n");
+            }
+            LOG_DBG("                                      ],\n");
+        }
+        LOG_DBG("                                     ]\n");
+        LOG_DBG("                                     sum    = %f\n", sum);
+        LOG_DBG("                                     sum_sq = %f\n", sum_sq);
+    }
+
+    if (std::isnan(sum)) {
+        LOG_ERR("encountered NaN - aborting\n");
+        exit(0);
+    }
+}
+
+/**
+ * GGML operations callback during the graph execution.
+ *
+ * @param t current tensor
+ * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor
+ *            if we return true, a follow-up call will be made with ask=false in which we can do the actual collection.
+ *            see ggml_backend_sched_eval_callback
+ * @param user_data user data to pass at each call back
+ * @return true to receive data or continue the graph, false otherwise
+ */
+static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
+    auto * cb_data = (callback_data *) user_data;
+
+    const struct ggml_tensor * src0 = t->src[0];
+    const struct ggml_tensor * src1 = t->src[1];
+
+    if (ask) {
+        return true; // Always retrieve data
+    }
+
+    bool matches_filter = cb_data->tensor_filters.empty();
+
+    if (!matches_filter) {
+        for (const auto & filter : cb_data->tensor_filters) {
+            if (std::regex_search(t->name, filter)) {
+                matches_filter = true;
+                break;
+            }
+        }
+    }
+
+    char src1_str[128] = {0};
+    if (src1) {
+        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
+    }
+
+    if (matches_filter) {
+        LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
+             t->name,
+             ggml_type_name(t->type),
+             ggml_op_desc(t),
+             src0->name,
+             ggml_ne_string(src0).c_str(),
+             src1 ? src1_str : "",
+             ggml_ne_string(t).c_str());
+    }
+
+    const bool is_host = ggml_backend_buffer_is_host(t->buffer);
+
+    if (!is_host) {
+        auto n_bytes = ggml_nbytes(t);
+        cb_data->data.resize(n_bytes);
+        ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
+    }
+
+    if (!ggml_is_quantized(t->type) && matches_filter) {
+        uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
+        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
+    }
+
+    return true;
+}
+
+
+static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
+    std::filesystem::create_directory(output_dir);
+    auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);
+
+    // Save logits/embeddings to binary file.
+    {
+        std::filesystem::path filepath{base_path.string() + ".bin"};
+        std::ofstream file{filepath, std::ios::binary};
+        if (!file) {
+            throw std::runtime_error("failed to open binary output file: " + filepath.string());
+        }
+        file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
+        LOG("Data saved to %s\n", filepath.c_str());
+    }
+
+    // Save logits/embeddings to text file.
+    {
+        std::filesystem::path filepath{base_path.string() + ".txt"};
+        std::ofstream file{filepath};
+        if (!file) {
+            throw std::runtime_error("failed to open text output file: " + filepath.string());
+        }
+        for (int i = 0; i < output.data_size; i++) {
+            file << i << ": " << output.data_ptr[i] << '\n';
+        }
+        LOG("Data saved to %s\n", filepath.c_str());
+    }
+
+    // Save prompt and tokens to text file.
+    {
+        std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
+        std::ofstream file{filepath};
+        if (!file) {
+            throw std::runtime_error("failed to open prompt output file: " + filepath.string());
+        }
+
+        file << "prompt: " << output.prompt << '\n';
+        file << "n_tokens: " << output.tokens.size() << '\n';
+
+        file << "token ids: ";
+        for (size_t i = 0; i < output.tokens.size(); i++) {
+            file << output.tokens[i];
+            if (i + 1 < output.tokens.size()) {
+                file << ", ";
+            }
+        }
+        file << '\n';
+        LOG("Prompt saved to %s\n", filepath.c_str());
+    }
+
+    // Save token ids to binary file.
+    {
+        std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
+        std::ofstream file{filepath, std::ios::binary};
+        if (!file) {
+            throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
+        }
+        file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
+        LOG("Tokens saved to %s\n", filepath.c_str());
+    }
+
+}
+
+static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
+    LOG("Input prompt: \"%s\"\n", prompt.c_str());
+    LOG("Token ids (%zu):\n", tokens.size());
+
+    for (auto id : tokens) {
+        std::string piece(128, '\0');
+        int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
+        if (n < 0) {
+            LOG_ERR("failed to convert token %d to piece\n", id);
+            continue;
+        }
+        piece.resize(n);
+        LOG("%s(%d) ", piece.c_str(), id);
+    }
+    LOG("\n");
+}
+
+static bool run(llama_context * ctx, const common_params & params) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
+
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
+
+    if (tokens.empty()) {
+        LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
+        return false;
+    }
+
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+        LOG_ERR("%s : failed to eval\n", __func__);
+        return false;
+    }
+
+    print_tokenized_prompt(ctx, tokens, params.prompt);
+
+    if (params.save_logits) {
+        output_data output {ctx, model, params};
+        std::filesystem::path model_path{params.model.path};
+        std::string model_name{model_path.stem().string()};
+        save_output_data(output, model_name, params.logits_output_dir);
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
+        return 1;
+    }
+
+    common_init();
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    callback_data cb_data(params, params.tensor_filter);
+
+    auto llama_init = common_init_from_params(params);
+
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
+
+    if (model == nullptr || ctx == nullptr) {
+        LOG_ERR("%s : failed to init\n", __func__);
+        return 1;
+    }
+
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+        LOG_INF("\n");
+    }
+
+    if (!run(ctx, params)) {
+        return 1;
+    }
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/examples/model-conversion/CMakeLists.txt b/examples/model-conversion/CMakeLists.txt
deleted file mode 100644 (file)
index fc1746c..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-set(TARGET llama-logits)
-add_executable(${TARGET} logits.cpp)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp
deleted file mode 100644 (file)
index f71f772..0000000
+++ /dev/null
@@ -1,268 +0,0 @@
-#include "llama.h"
-#include "common.h"
-
-
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <vector>
-#include <ctype.h>
-#include <filesystem>
-
-static void print_usage(int, char ** argv) {
-    printf("\nexample usage:\n");
-    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
-    printf("\n");
-    printf("  -embd-norm: normalization type for pooled embeddings (default: 2)\n");
-    printf("              -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
-    printf("\n");
-}
-
-int main(int argc, char ** argv) {
-    std::string model_path;
-    std::string prompt = "Hello, my name is";
-    int ngl = 0;
-    bool embedding_mode = false;
-    bool pooling_enabled = false;
-    int32_t embd_norm = 2;  // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
-
-    {
-        int i = 1;
-        for (; i < argc; i++) {
-            if (strcmp(argv[i], "-m") == 0) {
-                if (i + 1 < argc) {
-                    model_path = argv[++i];
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else if (strcmp(argv[i], "-ngl") == 0) {
-                if (i + 1 < argc) {
-                    try {
-                        ngl = std::stoi(argv[++i]);
-                    } catch (...) {
-                        print_usage(argc, argv);
-                        return 1;
-                    }
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else if (strcmp(argv[i], "-embd-mode") == 0) {
-                embedding_mode = true;
-            } else if (strcmp(argv[i], "-pooling") == 0) {
-                pooling_enabled = true;
-            } else if (strcmp(argv[i], "-embd-norm") == 0) {
-                if (i + 1 < argc) {
-                    try {
-                        embd_norm = std::stoi(argv[++i]);
-                    } catch (...) {
-                        print_usage(argc, argv);
-                        return 1;
-                    }
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else {
-                // prompt starts here
-                break;
-            }
-        }
-
-        if (model_path.empty()) {
-            print_usage(argc, argv);
-            return 1;
-        }
-
-        if (i < argc) {
-            prompt = argv[i++];
-            for (; i < argc; i++) {
-                prompt += " ";
-                prompt += argv[i];
-            }
-        }
-    }
-
-    ggml_backend_load_all();
-    llama_model_params model_params = llama_model_default_params();
-    model_params.n_gpu_layers = ngl;
-
-    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
-
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
-    }
-
-    // Extract basename from model_path
-    const char * basename = strrchr(model_path.c_str(), '/');
-    basename = (basename == NULL) ? model_path.c_str() : basename + 1;
-
-    char model_name[256];
-    strncpy(model_name, basename, 255);
-    model_name[255] = '\0';
-
-    char * dot = strrchr(model_name, '.');
-    if (dot != NULL && strcmp(dot, ".gguf") == 0) {
-        *dot = '\0';
-    }
-    printf("Model name: %s\n", model_name);
-
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-
-    std::vector<llama_token> prompt_tokens(n_prompt);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
-        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
-        return 1;
-    }
-
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = n_prompt;
-    ctx_params.n_batch = n_prompt;
-    ctx_params.no_perf = false;
-    if (embedding_mode) {
-        ctx_params.embeddings = true;
-        ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
-        ctx_params.n_ubatch = ctx_params.n_batch;
-    }
-
-    llama_context * ctx = llama_init_from_model(model, ctx_params);
-    if (ctx == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
-    }
-
-    printf("Input prompt: \"%s\"\n", prompt.c_str());
-    printf("Tokenized prompt (%d tokens): ", n_prompt);
-    for (auto id : prompt_tokens) {
-        char buf[128];
-        int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
-        if (n < 0) {
-            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
-            return 1;
-        }
-        std::string s(buf, n);
-        printf("%s (%d)", s.c_str(), id);
-    }
-    printf("\n");
-
-    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
-
-    if (llama_decode(ctx, batch)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
-        return 1;
-    }
-
-    float * data_ptr;
-    int data_size;
-    const char * type;
-    std::vector<float> embd_out;
-
-    if (embedding_mode) {
-        const int n_embd_out = llama_model_n_embd_out(model);
-        const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
-        const int n_embeddings = n_embd_out * n_embd_count;
-        float * embeddings;
-        type = "-embeddings";
-
-        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
-            embeddings = llama_get_embeddings_seq(ctx, 0);
-            embd_out.resize(n_embeddings);
-            printf("Normalizing embeddings using norm: %d\n", embd_norm);
-            common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
-            embeddings = embd_out.data();
-        } else {
-            embeddings = llama_get_embeddings(ctx);
-        }
-
-        printf("Embedding dimension: %d\n", n_embd_out);
-        printf("\n");
-
-        // Print embeddings in the specified format
-        for (int j = 0; j < n_embd_count; j++) {
-            printf("embedding %d: ", j);
-
-            // Print first 3 values
-            for (int i = 0; i < 3 && i < n_embd_out; i++) {
-                printf("%9.6f ", embeddings[j * n_embd_out + i]);
-            }
-
-            printf(" ... ");
-
-            // Print last 3 values
-            for (int i = n_embd_out - 3; i < n_embd_out; i++) {
-                if (i >= 0) {
-                    printf("%9.6f ", embeddings[j * n_embd_out + i]);
-                }
-            }
-
-            printf("\n");
-        }
-        printf("\n");
-
-        printf("Embeddings size: %d\n", n_embeddings);
-
-        data_ptr = embeddings;
-        data_size = n_embeddings;
-    } else {
-        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-        const int n_logits = llama_vocab_n_tokens(vocab);
-        type = "";
-        printf("Vocab size: %d\n", n_logits);
-
-        data_ptr = logits;
-        data_size = n_logits;
-    }
-
-    std::filesystem::create_directory("data");
-
-    // Save data to binary file
-    char bin_filename[512];
-    snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
-    printf("Saving data to %s\n", bin_filename);
-
-    FILE * f = fopen(bin_filename, "wb");
-    if (f == NULL) {
-        fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
-        return 1;
-    }
-    fwrite(data_ptr, sizeof(float), data_size, f);
-    fclose(f);
-
-    // Also save as text for debugging
-    char txt_filename[512];
-    snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type);
-    f = fopen(txt_filename, "w");
-    if (f == NULL) {
-        fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
-        return 1;
-    }
-    for (int i = 0; i < data_size; i++) {
-        fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
-    }
-    fclose(f);
-
-    if (!embedding_mode) {
-        printf("First 10 logits: ");
-        for (int i = 0; i < 10 && i < data_size; i++) {
-            printf("%.6f ", data_ptr[i]);
-        }
-        printf("\n");
-
-        printf("Last 10 logits: ");
-        for (int i = data_size - 10; i < data_size; i++) {
-            if (i >= 0) printf("%.6f ", data_ptr[i]);
-        }
-        printf("\n\n");
-    }
-
-    printf("Data saved to %s\n", bin_filename);
-    printf("Data saved to %s\n", txt_filename);
-
-    llama_free(ctx);
-    llama_model_free(model);
-
-    return 0;
-}
index 894302c69eb78b8c29c71494773ba1e9de516da5..1a933207d59b0995ab260d0642611e3fb8d7489a 100755 (executable)
@@ -6,7 +6,7 @@ from pathlib import Path
 
 # Add utils directory to path for direct script execution
 sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
-from common import get_model_name_from_env_path  # type: ignore[import-not-found]
+from common import get_model_name_from_env_path, compare_tokens  # type: ignore[import-not-found]
 
 def quick_logits_check(pytorch_file, llamacpp_file):
     """Lightweight sanity check before NMSE"""
@@ -58,6 +58,13 @@ def main():
 
     print("Checked all required files were found. Proceeding...\n")
 
+    # Verify tokens as they are a prerequisite for logits comparison.
+    print("šŸ” Token Comparison Check")
+    print("=" * 40)
+    if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
+        print("\nāŒ Token mismatch detected")
+        sys.exit(1)
+    print()
 
     print("šŸ” GGML Model Validation for model ", model_name)
     print("=" * 40)
index 55ad821385f32c4a8633e6f14c69fb497bdea35e..4ab778fbc79000af6f337e19779f10ad3baecfbb 100755 (executable)
@@ -67,7 +67,7 @@ with torch.no_grad():
     last_hidden_states = outputs.hidden_states[-1]
 
     # Get embeddings for all tokens
-    token_embeddings = last_hidden_states[0].cpu().numpy()  # Remove batch dimension
+    token_embeddings = last_hidden_states[0].float().cpu().numpy()  # Remove batch dimension
 
     print(f"Hidden states shape: {last_hidden_states.shape}")
     print(f"Token embeddings shape: {token_embeddings.shape}")
index fa16a02c6599c3a105ea839d0cdc3c3a1d0e240a..3cce3fc94dd361e7f7d05a5465adfd38e521d0d8 100755 (executable)
@@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then
     exit 1
 fi
 
-cmake --build ../../build --target llama-logits -j8
+cmake --build ../../build --target llama-debug -j8
 
-../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today"
+../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
index 529e9987b01978b0ee2091f86e65b7ae976933b6..b6c3d386628fbf904a0e9f6a21099a63c66b31a7 100755 (executable)
@@ -21,6 +21,6 @@ fi
 echo $CONVERTED_MODEL
 echo $MODEL_TESTING_PROMPT
 
-cmake --build ../../build --target llama-logits -j8
+cmake --build ../../build --target llama-debug -j8
 
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
+../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
index b12173a1fba40d04bde12578bda8a89cab5654e1..215f1a9ee0cf73551806aa2da36419d7d41e843d 100755 (executable)
@@ -7,12 +7,11 @@ import importlib
 import torch
 import numpy as np
 
-from pathlib import Path
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 
 # Add parent directory to path for imports
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-from utils.common import debug_hook
+from utils.common import debug_hook, save_output_data
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Process model with specified path")
@@ -126,6 +125,7 @@ def main():
     device = next(model.parameters()).device
     prompt = get_prompt(args)
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    token_ids = input_ids[0].cpu().tolist()
 
     print(f"Input tokens: {input_ids}")
     print(f"Input text: {repr(prompt)}")
@@ -151,19 +151,6 @@ def main():
         print(f"Last token logits shape: {last_logits.shape}")
         print(f"Vocab size: {len(last_logits)}")
 
-        data_dir = Path("data")
-        data_dir.mkdir(exist_ok=True)
-        bin_filename = data_dir / f"pytorch-{model_name}.bin"
-        txt_filename = data_dir / f"pytorch-{model_name}.txt"
-
-        # Save to file for comparison
-        last_logits.astype(np.float32).tofile(bin_filename)
-
-        # Also save as text file for easy inspection
-        with open(txt_filename, "w") as f:
-            for i, logit in enumerate(last_logits):
-                f.write(f"{i}: {logit:.6f}\n")
-
         # Print some sample logits for quick verification
         print(f"First 10 logits: {last_logits[:10]}")
         print(f"Last 10 logits: {last_logits[-10:]}")
@@ -175,8 +162,7 @@ def main():
             token = tokenizer.decode([idx])
             print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
 
-        print(f"Saved bin logits to: {bin_filename}")
-        print(f"Saved txt logist to: {txt_filename}")
+        save_output_data(last_logits, token_ids, prompt, model_name)
 
 if __name__ == "__main__":
     main()
index 0f490e6c3b20aac378907f65ec4783eb995697a2..5d264b066387b36b170a666a8fbe74382c598a32 100755 (executable)
@@ -50,10 +50,9 @@ fi
 
 echo $CONVERTED_MODEL
 
-cmake --build ../../build --target llama-logits -j8
-# TODO: update logits.cpp to accept a --file/-f option for the prompt
+cmake --build ../../build --target llama-debug -j8
 if [ -n "$USE_POOLING" ]; then
-    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
+    ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits
 else
-    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+    ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits
 fi
index 774e5638f7238c379077603c04e293d1d679e361..0802cbcf4a2f5eb36524f2e346ce4cf50f2a28ac 100755 (executable)
@@ -3,13 +3,15 @@
 import argparse
 import os
 import sys
-import numpy as np
 import importlib
-from pathlib import Path
 
 from transformers import AutoTokenizer, AutoConfig, AutoModel
 import torch
 
+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+from utils.common import save_output_data
+
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description='Run original embedding model')
@@ -169,6 +171,7 @@ def main():
                 return_tensors="pt"
             )
             tokens = encoded['input_ids'][0]
+            token_ids = tokens.cpu().tolist()
             token_strings = tokenizer.convert_ids_to_tokens(tokens)
             for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
                 print(f"{token_id:6d} -> '{token_str}'")
@@ -185,6 +188,7 @@ def main():
             )
 
             tokens = encoded['input_ids'][0]
+            token_ids = tokens.cpu().tolist()
             token_strings = tokenizer.convert_ids_to_tokens(tokens)
             for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
                 print(f"{token_id:6d} -> '{token_str}'")
@@ -228,24 +232,11 @@ def main():
 
         print()
 
-        data_dir = Path("data")
-        data_dir.mkdir(exist_ok=True)
-        bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
-        txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
-
         flattened_embeddings = all_embeddings.flatten()
-        flattened_embeddings.astype(np.float32).tofile(bin_filename)
-
-        with open(txt_filename, "w") as f:
-            idx = 0
-            for j in range(n_embd_count):
-                for value in all_embeddings[j]:
-                    f.write(f"{idx}: {value:.6f}\n")
-                    idx += 1
         print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings Ć— {n_embd} dimensions)")
         print("")
-        print(f"Saved bin embeddings to: {bin_filename}")
-        print(f"Saved txt embeddings to: {txt_filename}")
+
+        save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings")
 
 
 if __name__ == "__main__":
index 7595d0410edef5d9af41976585f46bd8a5d04f61..71761127bb0e3204c6934825efd814080347adc9 100644 (file)
@@ -3,6 +3,8 @@
 import os
 import sys
 import torch
+import numpy as np
+from pathlib import Path
 
 
 def get_model_name_from_env_path(env_path_name):
@@ -148,3 +150,96 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_
     # Patch it
     setattr(module, function_name, debug_rope)
     print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
+
+
+def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
+    """
+    Save output data (logits/embeddings), tokens, and prompt to files.
+
+    Args:
+        data:        numpy array of floats (logits or embeddings)
+        tokens:      list or array of token IDs
+        prompt:      string containing the input prompt
+        model_name:  name of the model
+        type_suffix: optional suffix like "-embeddings" (default: "")
+        output_dir:  directory to save files (default: "data")
+
+    Creates the following files in output_dir:
+        - pytorch-{model_name}{type_suffix}.bin
+        - pytorch-{model_name}{type_suffix}.txt
+        - pytorch-{model_name}{type_suffix}-prompt.txt
+        - pytorch-{model_name}{type_suffix}-tokens.bin
+    """
+    data_dir = Path(output_dir)
+    data_dir.mkdir(exist_ok=True)
+    base_path = data_dir / f"pytorch-{model_name}{type_suffix}"
+
+    # Convert and flatten logits/embeddings
+    data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
+    data = data.flatten() if data.ndim > 1 else data
+
+    # Save logits/embedding files
+    data.astype(np.float32).tofile(f"{base_path}.bin")
+    print(f"Data saved to {base_path}.bin")
+
+    with open(f"{base_path}.txt", "w") as f:
+        f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
+    print(f"Data saved to {base_path}.txt")
+
+    # Convert and flatten tokens
+    tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
+    tokens = tokens.flatten() if tokens.ndim > 1 else tokens
+
+    # Save token binary file
+    tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
+    print(f"Tokens saved to {base_path}-tokens.bin")
+
+    # Save prompt file
+    with open(f"{base_path}-prompt.txt", "w") as f:
+        f.write(f"prompt: {prompt}\n")
+        f.write(f"n_tokens: {len(tokens)}\n")
+        f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
+    print(f"Prompt saved to {base_path}-prompt.txt")
+
+
+def compare_tokens(original, converted, type_suffix="", output_dir="data"):
+    data_dir = Path(output_dir)
+
+    # Read tokens from both models
+    tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
+    tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"
+
+    if not tokens1_file.exists():
+        print(f"Error: Token file not found: {tokens1_file}")
+        return False
+
+    if not tokens2_file.exists():
+        print(f"Error: Token file not found: {tokens2_file}")
+        return False
+
+    tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
+    tokens2 = np.fromfile(tokens2_file, dtype=np.int32)
+
+    print(f"\nComparing tokens between:")
+    print(f"  Original : {original} ({len(tokens1)} tokens)")
+    print(f"  Converted: {converted} ({len(tokens2)} tokens)")
+
+    if len(tokens1) != len(tokens2):
+        print(f"\nāŒ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
+        return False
+
+    if np.array_equal(tokens1, tokens2):
+        print(f"\nāœ… All {len(tokens1)} tokens match!")
+        return True
+
+    mismatches = np.where(tokens1 != tokens2)[0]
+    print(f"\nāŒ Found {len(mismatches)} mismatched tokens:")
+
+    num_to_show = min(len(mismatches), 10)
+    for idx in mismatches[:num_to_show]:
+        print(f"  Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")
+
+    if len(mismatches) > num_to_show:
+        print(f"  ... and {len(mismatches) - num_to_show} more mismatches")
+
+    return False
diff --git a/examples/model-conversion/scripts/utils/compare_tokens.py b/examples/model-conversion/scripts/utils/compare_tokens.py
new file mode 100755 (executable)
index 0000000..a286cb5
--- /dev/null
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+from common import compare_tokens  # type: ignore
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description='Compare tokens between two models',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
+        """
+    )
+    parser.add_argument(
+        'original',
+        help='Original model name'
+    )
+    parser.add_argument(
+        'converted',
+        help='Converted model name'
+    )
+    parser.add_argument(
+        '-s', '--suffix',
+        default='',
+        help='Type suffix (e.g., "-embeddings")'
+    )
+    parser.add_argument(
+        '-d', '--data-dir',
+        default='data',
+        help='Directory containing token files (default: data)'
+    )
+    parser.add_argument(
+        '-v', '--verbose',
+        action='store_true',
+        help='Print prompts from both models'
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_arguments()
+
+    if args.verbose:
+        from pathlib import Path
+        data_dir = Path(args.data_dir)
+
+        prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt"
+        prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt"
+
+        if prompt1_file.exists():
+            print(f"\nOriginal model prompt ({args.original}):")
+            print(f"  {prompt1_file.read_text().strip()}")
+
+        if prompt2_file.exists():
+            print(f"\nConverted model prompt ({args.converted}):")
+            print(f"  {prompt2_file.read_text().strip()}")
+
+        print()
+
+    result = compare_tokens(
+        args.original,
+        args.converted,
+        type_suffix=args.suffix,
+        output_dir=args.data_dir
+    )
+
+    # Enable the script to be used in shell scripts so that they can check
+    # the exit code for success/failure.
+    sys.exit(0 if result else 1)
+
+
+if __name__ == "__main__":
+    main()
index e64c0004974e63643a06f887c54811d6e25c2732..38b03ce4d2cc0dbd3362487138230be8f2cfed5f 100644 (file)
@@ -4,8 +4,10 @@ import numpy as np
 import argparse
 import os
 import importlib
+from pathlib import Path
 
 from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
+from common import compare_tokens  # type: ignore[import-not-found]
 
 unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
 
@@ -157,9 +159,25 @@ def main():
     else:
         prompt = args.prompt
 
+    python_emb_path = Path(args.python_embeddings)
+    cpp_emb_path = Path(args.cpp_embeddings)
+
+    # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
+    python_model_name = python_emb_path.stem.replace("-embeddings", "")
+    cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "")
+
     print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
     print("=" * 70)
 
+    # First verify tokens match before comparing embeddings
+    print("\nšŸ” Token Comparison Check")
+    print("=" * 70)
+    data_dir = python_emb_path.parent
+    if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)):
+        print("\nāŒ Token mismatch detected")
+        exit(1)
+    print()
+
     # Single prompt detailed comparison
     print(f"\nTesting with prompt: '{prompt}'")