]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
feat: remove a sampler from a chain (#9445)
authorGilad S. <redacted>
Fri, 13 Sep 2024 01:54:49 +0000 (04:54 +0300)
committerGitHub <redacted>
Fri, 13 Sep 2024 01:54:49 +0000 (03:54 +0200)
* feat: remove a sampler from a chain

* fix: return removed sampler

* fix: safer casting

include/llama.h
src/llama-sampling.cpp

index 405af912c46868be5fb41e2e21e9439ab5c052a5..744ef9d900abf8f75e1c4121b1dc130d09c6ed57 100644 (file)
@@ -1056,6 +1056,9 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
     LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
 
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(   struct llama_sampler * chain, int32_t i);
+
     // available samplers:
 
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);
index fd1b7f9196f373af1f7a1a3897de3c707f28425d..c828dc359b58bca9a52736d363de3a19478548f6 100644 (file)
@@ -349,13 +349,26 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
 struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
     const auto * p = (const llama_sampler_chain *) chain->ctx;
 
-    if (i < 0 || i >= (int32_t) p->samplers.size()) {
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
         return nullptr;
     }
 
     return p->samplers[i];
 }
 
+struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || (size_t) i >= p->samplers.size()) {
+        return nullptr;
+    }
+
+    auto * result = p->samplers[i];
+    p->samplers.erase(p->samplers.begin() + i);
+
+    return result;
+}
+
 int llama_sampler_chain_n(const struct llama_sampler * chain) {
     const auto * p = (const llama_sampler_chain *) chain->ctx;