]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : check all graph nodes when searching for result_embd_pooled (#8956)
authorfairydreaming <redacted>
Sun, 11 Aug 2024 08:35:26 +0000 (10:35 +0200)
committerGitHub <redacted>
Sun, 11 Aug 2024 08:35:26 +0000 (10:35 +0200)
Co-authored-by: Stanisław Szymczyk <redacted>
src/llama.cpp

index e0fe8013b0ad2ea3a7f2ce9e27607ee832f83e23..aaf8db496ecbdc73a753f3577896bac638778e5c 100644 (file)
@@ -14722,12 +14722,15 @@ static int llama_decode_internal(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");