]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
embedding : adjust `n_ubatch` value (#6296)
authorMinsoo Cheong <redacted>
Tue, 26 Mar 2024 09:11:46 +0000 (18:11 +0900)
committerGitHub <redacted>
Tue, 26 Mar 2024 09:11:46 +0000 (11:11 +0200)
* embedding: assign `n_ubatch` value, print error on `n_batch` overflow

* Update examples/embedding/embedding.cpp

Co-authored-by: Xuan Son Nguyen <redacted>
* use %ld instead of %lld

* Revert "use %ld instead of %lld"

This reverts commit ea753ede90a86a0699f65878cc8e2020ff5eabb8.

---------

Co-authored-by: Xuan Son Nguyen <redacted>
examples/embedding/embedding.cpp

index cbf9aa2b560dde1036373a213e6c06d185a9254f..9aede7fadfe31f607516334ab814eb9e9c71cf76 100644 (file)
@@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
     }
 
     params.embedding = true;
+    // For non-causal models, batch size must be equal to ubatch size
+    params.n_ubatch = params.n_batch;
 
     print_build_info();
 
@@ -114,7 +116,9 @@ int main(int argc, char ** argv) {
     for (const auto & prompt : prompts) {
         auto inp = ::llama_tokenize(ctx, prompt, true, false);
         if (inp.size() > n_batch) {
-            inp.resize(n_batch);
+            fprintf(stderr, "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
+                    __func__, (long long int) inp.size(), (long long int) n_batch);
+            return 1;
         }
         inputs.push_back(inp);
     }