}
};
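+
+// returns true when the context was created with a pooling type that
+// produces a single pooled embedding per sequence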
+static bool has_pooling(llama_context * ctx) {
+    switch (llama_pooling_type(ctx)) {
+        case LLAMA_POOLING_TYPE_NONE:
+        case LLAMA_POOLING_TYPE_UNSPECIFIED:
+            return false;
+        default:
+            return true;
+    }
+}
+
struct output_data {
    float * data_ptr = nullptr;
    int data_size = 0;
    std::string type_suffix;
-    std::vector<float> storage;
+    std::vector<float> embd_norm; // owns the normalized copy of the embeddings when normalization is requested
    std::string prompt;
    std::vector<llama_token> tokens;
    prompt = params.prompt;
    if (params.embedding) {
-        const int n_embd = llama_model_n_embd_out(model);
-        const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
-        const int n_embd_count = pooling_enabled ? 1 : tokens.size();
-        const int n_embeddings = n_embd * n_embd_count;
-
-        float * embeddings;
-        if (pooling_enabled) {
-            embeddings = llama_get_embeddings_seq(ctx, 0);
-            storage.resize(n_embeddings);
-            common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
-            embeddings = storage.data();
-        } else {
-            embeddings = llama_get_embeddings(ctx);
+        const int n_embd = llama_model_n_embd_out(model);
+        const bool pooling = has_pooling(ctx);
+        const int n_embd_count = pooling ? 1 : (int) tokens.size();
+        const int n_floats = n_embd * n_embd_count;
+
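+        // with pooling, the backend stores a single pooled embedding for the
+        // sequence (seq_id 0); without pooling, one embedding per token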
+        float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx);
+        if (embd_raw == nullptr) {
+            throw std::runtime_error("failed to get embeddings from the model");
        }
-        data_ptr = embeddings;
-        data_size = n_embeddings;
+        LOG_DBG("pooling: %s\n", pooling ? "true" : "false");
+        LOG_DBG("n_embd: %d\n", n_embd);
+        LOG_DBG("n_embd_count: %d\n", n_embd_count);
+        LOG_DBG("n_floats: %d\n", n_floats);
+
+        data_ptr = embd_raw;
+        data_size = n_floats;
        type_suffix = "-embeddings";
+
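+        // embd_normalize >= 0 requests normalization (negative values leave the
+        // raw embeddings); normalize each n_embd-sized row separately and
+        // export the normalized copy instead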
+        if (params.embd_normalize >= 0) {
+            embd_norm.resize(n_floats);
+            for (int i = 0; i < n_embd_count; i++) {
+                common_embd_normalize(embd_raw + i * n_embd, embd_norm.data() + i * n_embd, n_embd, params.embd_normalize);
+            }
+            data_ptr = embd_norm.data();
+        }
    } else {
        const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
        const int n_logits = llama_vocab_n_tokens(vocab);