res = nullptr;
embd = nullptr;
} else if (cparams.embeddings) {
- res = nullptr; // do not extract logits for embedding case
- embd = gf->nodes[gf->n_nodes - 1];
- if (strcmp(embd->name, "result_embd_pooled") != 0) {
- embd = gf->nodes[gf->n_nodes - 2];
+ res = nullptr; // do not extract logits for embedding case
+ embd = nullptr;
+ for (int i = gf->n_nodes - 1; i >= 0; --i) {
+ if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+ embd = gf->nodes[i];
+ break;
+ }
}
- GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+ GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");