fprintf(stderr, "\n");
}
- if (params.embedding){
- if (embd_inp.size() > 0) {
- if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
- return 1;
- }
+ if (embd_inp.size() > (size_t)params.n_ctx) {
+ fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+ __func__, embd_inp.size(), params.n_ctx);
+ return 1;
+ }
+
+ while (!embd_inp.empty()) {
+ int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+ if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
}
+ n_past += n_tokens;
+ embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
+ }
- const int n_embd = llama_n_embd(ctx);
- const auto embeddings = llama_get_embeddings(ctx);
+ const int n_embd = llama_n_embd(ctx);
+ const auto embeddings = llama_get_embeddings(ctx);
- for (int i = 0; i < n_embd; i++) {
- printf("%f ", embeddings[i]);
- }
- printf("\n");
+ for (int i = 0; i < n_embd; i++) {
+ printf("%f ", embeddings[i]);
}
+ printf("\n");
llama_print_timings(ctx);
llama_free(ctx);