]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
embeddings: fix extraction of CLS pooling results (#14927)
authorDouglas Hanley <redacted>
Wed, 30 Jul 2025 05:25:05 +0000 (00:25 -0500)
committerGitHub <redacted>
Wed, 30 Jul 2025 05:25:05 +0000 (08:25 +0300)
* embeddings: fix extraction of CLS pooling results

* merge RANK pooling into CLS case for inputs

src/llama-graph.cpp

index 1b9cc4aec0632f93b97377e114d6ccd60f17652d..702192b79df6ece99d152eaac00fcb0a8f39f009 100644 (file)
@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
 
     if (cparams.embeddings && (
-            cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-            cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
-        )) {
+        cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  ||
+        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
+        cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
+    )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
-            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
-                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
-
-                data[seq_idx] = i;
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        GGML_ASSERT(cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
-
-        uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+        std::vector<int> target_pos(n_seqs_unq, -1);
+        std::vector<int> target_row(n_seqs_unq, -1);
 
-        std::vector<int> last_pos(n_seqs_unq, -1);
-        std::vector<int> last_row(n_seqs_unq, -1);
+        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
                 const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                 const int32_t      seq_idx = ubatch->seq_idx[seq_id];
 
-                if (pos >= last_pos[seq_idx]) {
-                    last_pos[seq_idx] = pos;
-                    last_row[seq_idx] = i;
+                if (
+                    (target_pos[seq_idx] == -1) ||
+                    ( last && pos >= target_pos[seq_idx]) ||
+                    (!last && pos <  target_pos[seq_idx])
+                ) {
+                    target_pos[seq_idx] = pos;
+                    target_row[seq_idx] = i;
                 }
             }
         }
 
         for (int s = 0; s < n_seqs_unq; ++s) {
-            if (last_row[s] >= 0) {
-                data[s] = last_row[s];
+            if (target_row[s] >= 0) {
+                data[s] = target_row[s];
             }
         }
     }