]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support multiple classifier outputs and labels (#13940)
authorSigbjørn Skjæret <redacted>
Fri, 6 Jun 2025 07:03:25 +0000 (09:03 +0200)
committerGitHub <redacted>
Fri, 6 Jun 2025 07:03:25 +0000 (09:03 +0200)
examples/embedding/embedding.cpp
include/llama.h
src/llama-context.cpp
src/llama-model-loader.cpp
src/llama-model.cpp
src/llama-model.h

index 71f700877a3b9faca3247231153b631e3f135174..8bef7f8f6ba25a085fbd80bdbf30d57a9f724787 100644 (file)
@@ -236,9 +236,24 @@ int main(int argc, char ** argv) {
                 LOG("\n");
             }
         } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            const uint32_t n_cls_out = llama_model_n_cls_out(model);
+            std::vector<std::string> cls_out_labels;
+
+            for (uint32_t i = 0; i < n_cls_out; i++) {
+                const char * label = llama_model_cls_label(model, i);
+                const std::string label_i(label == nullptr ? "" : label);
+                cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
+            }
+
             for (int j = 0; j < n_embd_count; j++) {
-                // NOTE: if you change this log - update the tests in ci/run.sh
-                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                for (uint32_t i = 0; i < n_cls_out; i++) {
+                    // NOTE: if you change this log - update the tests in ci/run.sh
+                    if (n_cls_out == 1) {
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                    } else {
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                    }
+                }
             }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
index 21808c881c7b3ff27a84ac9171ea59442a9f5d4d..aa5330e2a5977bdafbd150a45f9bbdab8391a8bf 100644 (file)
@@ -514,6 +514,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@@ -992,7 +999,7 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
index c29fe7e4ce41dc8caf52075e463149d968051924..d94bf8643492197db9e34ce963825338090260e6 100644 (file)
@@ -839,16 +839,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+                    const uint32_t n_cls_out = hparams.n_cls_out;
 
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
index ddb1b03675b289acdf53f66c4e5cb1c2ad80b589..bd9e6da8832b78c7d5a1f4661ef84c33269bea10 100644 (file)
@@ -288,9 +288,10 @@ namespace GGUFMeta {
 
     template<typename T>
     bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -298,28 +299,40 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) ||
+                                                (std::is_same<T,    uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+            result.clear();
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result.emplace_back(value);
+            }
+        } else {
+            result.resize(arr_info.length);
+            result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        }
 
         return true;
     }
 
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -327,22 +340,32 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) ||
+                                                (std::is_same<T,    uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
         if (arr_info.length > N_MAX) {
             throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
         }
 
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result[i] = value;
+            }
+        } else {
+            std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        }
 
         return true;
     }
@@ -352,6 +375,8 @@ namespace GGUFMeta {
         return get_arr(llm_kv(kid), result, required);
     }
 
+    template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
+
     template<typename T>
     bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
         auto it = kv_overrides.find(key);
index afef8487030fbdf70c0d47981e50ae503a28fc90..915d5a927c6358b09ad4d979758e3514b0c97699 100644 (file)
@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
 
+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
-                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -4362,6 +4367,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out        = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu]    = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }
 
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
@@ -13602,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }
 
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
index cbea2cb331b626f6ca2f829a186ec0822b20ce76..18b714620bbcf899e23828eb9653fcb80b728e49 100644 (file)
@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;
 
+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd   = nullptr;
     struct ggml_tensor * type_embd  = nullptr;
     struct ggml_tensor * pos_embd   = nullptr;