]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add getters for n_threads/n_threads_batch (#7464)
authorDaniel Bevenius <redacted>
Thu, 23 May 2024 12:29:26 +0000 (14:29 +0200)
committerGitHub <redacted>
Thu, 23 May 2024 12:29:26 +0000 (15:29 +0300)
* llama : add getters for n_threads/n_threads_batch

This commit adds two new functions to the llama API. The functions
can be used to get the number of threads used for generating a single
token and the number of threads used for prompt and batch processing
(multiple tokens).

The motivation for this is that we want to be able to get the number of
threads that the a context is using. The main use case is for a
testing/verification that the number of threads is set correctly.

Signed-off-by: Daniel Bevenius <redacted>
* squash! llama : add getters for n_threads/n_threads_batch

Rename the getters to llama_n_threads and llama_n_threads_batch.

Signed-off-by: Daniel Bevenius <redacted>
---------

Signed-off-by: Daniel Bevenius <redacted>
llama.cpp
llama.h

index 1f9e10eedde9ea0cc846ea38c9a1c5accd0d97d2..e540c1b392eaa69907abe3458022f18e1983f6d6 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -17410,6 +17410,14 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
 
+uint32_t llama_n_threads(struct llama_context * ctx) {
+    return ctx->cparams.n_threads;
+}
+
+uint32_t llama_n_threads_batch(struct llama_context * ctx) {
+    return ctx->cparams.n_threads_batch;
+}
+
 void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
     ctx->abort_callback      = abort_callback;
     ctx->abort_callback_data = abort_callback_data;
diff --git a/llama.h b/llama.h
index b7bf2afcb403e0824e4d0594a8c6646e1775247d..16cece5db0e78f55b56ac0ab87cd4c67ff1e77b3 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -759,6 +759,12 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
+    // Get the number of threads used for generation of a single token.
+    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+
+    // Get the number of threads used for prompt and batch processing (multiple token).
+    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);