]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tts : add guide tokens support (#11186)
authorLostRuins Concedo <redacted>
Sat, 18 Jan 2025 10:20:57 +0000 (18:20 +0800)
committerGitHub <redacted>
Sat, 18 Jan 2025 10:20:57 +0000 (12:20 +0200)
* Added the ability to use guide tokens for OuteTTS, greatly improving TTS recitation accuracy over long input sequences.

* applied linting suggestions, updated to latest llama_vocab changes, added a safety check, added newline to guide token start

common/arg.cpp
common/common.h
examples/tts/tts.cpp

index 9069950eb093936f5019613482e851faeba1adbc..dede335fbc3fdb5e28c96dfd2089f6913487ab88 100644 (file)
@@ -2254,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+     add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
 
     // model-specific
     add_opt(common_arg(
index 691141d6b6b2cc31d9b9646e9e6aa20f09f63002..3bcc637cc800ae5fb6029403a158f452e0490e4f 100644 (file)
@@ -184,6 +184,8 @@ struct common_params_vocoder {
 
     std::string model     = ""; // model path                                                // NOLINT
     std::string model_url = ""; // model url to download                                     // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy            // NOLINT
 };
 
 struct common_params {
index 5a91611811b4edbfef87fccd3b849d3294f830ef..f78f763033a2379298f2b9833f7dec94ff93da32 100644 (file)
@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    //first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens) {
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past   = batch.n_tokens;
         int n_decode = 0;
 
+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                     continue;
                 }
 
-                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+                    llama_token guide_token = guide_tokens[0];
+                    guide_tokens.erase(guide_tokens.begin());
+                    new_token_id = guide_token; //ensure correct word fragment is used
+                }
+
+                //this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);
 
                 common_sampler_accept(smpl[i], new_token_id, true);