]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add openai-style logit_bias support (#14946)
authorLukas Straub <redacted>
Thu, 31 Jul 2025 12:08:23 +0000 (14:08 +0200)
committerGitHub <redacted>
Thu, 31 Jul 2025 12:08:23 +0000 (14:08 +0200)
Signed-off-by: Lukas Straub <redacted>
tools/server/README.md
tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py
tools/server/tests/unit/test_completion.py

index f3f4caed85cf5a0928d0fcac95a828043d9be0c5..87cef75730afbf55a26bb122384d6bfda6eb40c8 100644 (file)
@@ -469,7 +469,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `ignore_eos`: Ignore end of stream token and continue generating.  Default: `false`
 
-`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
+`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. For compatibility with the OpenAI API, a JSON object {"<string or token id>": bias, ...} can also be passed. Default: `[]`
 
 `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
 
index 2e4c40af7839a6f1936af678b4023cc096fa4488..9a9b0444746f14cb72d4b88aefd8a46a7529c678 100644 (file)
@@ -473,6 +473,33 @@ struct server_task {
                         }
                     }
                 }
+           } else if (logit_bias != data.end() && logit_bias->is_object()) {
+                const int n_vocab = llama_vocab_n_tokens(vocab);
+                for (const auto & el : logit_bias->items()) {
+                    float bias;
+                    const auto & key = el.key();
+                    const auto & value = el.value();
+                    if (value.is_number()) {
+                        bias = value.get<float>();
+                    } else if (value.is_boolean() && !value.get<bool>()) {
+                        bias = -INFINITY;
+                    } else {
+                        continue;
+                    }
+
+                    char *end;
+                    llama_token tok = strtol(key.c_str(), &end, 10);
+                    if (*end == 0) {
+                        if (tok >= 0 && tok < n_vocab) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    } else {
+                        auto toks = common_tokenize(vocab, key, false);
+                        for (auto tok : toks) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    }
+                }
             }
 
             params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
index 7ee9a1651400daf5b5e6c25d3af09aca1c5f8c64..6c6f64f5e2ec4467b52a8af6e81c35205d57a22c 100644 (file)
@@ -351,3 +351,32 @@ def test_logprobs_stream():
                     assert token.top_logprobs is not None
                     assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
+
+
+def test_logit_bias():
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    res = server.make_request("POST", "/tokenize", data={
+        "content": " " + " ".join(exclude) + " ",
+    })
+    assert res.status_code == 200
+    tokens = res.body["tokens"]
+    logit_bias = {tok: -100 for tok in tokens}
+
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=64,
+        logit_bias=logit_bias
+    )
+    output_text = res.choices[0].message.content
+    assert output_text
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
index f6909e9ae788438ca3125da421642c58e1f83c46..be3a0052c64feefd194da01e81f590c18e04815f 100644 (file)
@@ -444,6 +444,39 @@ def test_n_probs_post_sampling():
         assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
 
 
+@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
+def test_logit_bias(tokenize, openai_style):
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    logit_bias = []
+    if tokenize:
+        res = server.make_request("POST", "/tokenize", data={
+            "content": " " + " ".join(exclude) + " ",
+        })
+        assert res.status_code == 200
+        tokens = res.body["tokens"]
+        logit_bias = [[tok, -100] for tok in tokens]
+
+    else:
+        logit_bias = [[" " + tok + " ", -100] for tok in exclude]
+
+    if openai_style:
+        logit_bias = {el[0]: -100 for el in logit_bias}
+
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 64,
+        "prompt": "What is the best book",
+        "logit_bias": logit_bias,
+        "temperature": 0.0
+    })
+    assert res.status_code == 200
+    output_text = res.body["content"]
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+
 def test_cancel_request():
     global server
     server.n_ctx = 4096