]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : fix for swedish umlauts + expose model inference settings in talk-llama...
authormatteng1 <redacted>
Mon, 26 May 2025 05:57:39 +0000 (07:57 +0200)
committerGitHub <redacted>
Mon, 26 May 2025 05:57:39 +0000 (07:57 +0200)
Quick fix for not removing swedish umlauts.

* Update talk-llama.cpp

Expose model inference settings to user instead of hard coding them. Same defaults as previous defaults.

* Update examples/talk-llama/talk-llama.cpp

Co-authored-by: Georgi Gerganov <redacted>
examples/talk-llama/talk-llama.cpp

index 9097c491b610affa6e60c2333d9d89362de5d44c..17ae1c95e11b20a423886f672bdef0755c7a8efd 100644 (file)
@@ -60,7 +60,13 @@ struct whisper_params {
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
     int32_t n_gpu_layers = 999;
-
+    int32_t seed = 0;
+    int32_t top_k = 5;
+    int32_t min_keep = 1;
+    float top_p = 0.80f;
+    float min_p = 0.01f;
+    float temp  = 0.30f;
+    
     float vad_thold  = 0.6f;
     float freq_thold = 100.0f;
 
@@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
         else if (arg == "-mt"  || arg == "--max-tokens")     { params.max_tokens     = std::stoi(argv[++i]); }
         else if (arg == "-ac"  || arg == "--audio-ctx")      { params.audio_ctx      = std::stoi(argv[++i]); }
         else if (arg == "-ngl" || arg == "--n-gpu-layers")   { params.n_gpu_layers   = std::stoi(argv[++i]); }
+        else if (arg == "--seed")                            { params.seed           = std::stoi(argv[++i]); }
+        else if (arg == "--top-k")                           { params.top_k          = std::stoi(argv[++i]); }
+        else if (arg == "--min-keep")                        { params.min_keep       = std::stoul(argv[++i]);}
+        else if (arg == "--top-p")                           { params.top_p          = std::stof(argv[++i]); }
+        else if (arg == "--min-p")                           { params.min_p          = std::stof(argv[++i]); }
+        else if (arg == "--temp")                            { params.temp           = std::stof(argv[++i]); }
         else if (arg == "-vth" || arg == "--vad-thold")      { params.vad_thold      = std::stof(argv[++i]); }
         else if (arg == "-fth" || arg == "--freq-thold")     { params.freq_thold     = std::stof(argv[++i]); }
         else if (arg == "-tr"  || arg == "--translate")      { params.translate      = true; }
@@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -mt N,    --max-tokens N   [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
     fprintf(stderr, "  -ac N,    --audio-ctx N    [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
     fprintf(stderr, "  -ngl N,   --n-gpu-layers N [%-7d] number of layers to store in VRAM\n",           params.n_gpu_layers);
+    fprintf(stderr, "  --seed N                   [%-7d] seed sampling\n",                               params.seed);
+    fprintf(stderr, "  --top-k N                  [%-7d] top-k sampling (0 = disabled)\n",               params.top_k);
+    fprintf(stderr, "  --min-keep N               [%-7d] minimum number of tokens to keep\n",            params.min_keep);
+    fprintf(stderr, "  --top-p N                  [%-7.2f] top-p sampling\n",                            params.top_p);
+    fprintf(stderr, "  --min-p N                  [%-7.2f] min-p sampling\n",                            params.min_p);
+    fprintf(stderr, "  --temp N                   [%-7.2f] temperature\n",                               params.temp);
     fprintf(stderr, "  -vth N,   --vad-thold N    [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
     fprintf(stderr, "  -fth N,   --freq-thold N   [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
     fprintf(stderr, "  -tr,      --translate      [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
@@ -409,21 +427,16 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
 
     // init sampler
-    const float top_k = 5;
-    const float top_p = 0.80f;
-    const float temp  = 0.30f;
-
-    const int seed = 0;
-
     auto sparams = llama_sampler_chain_default_params();
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    if (temp > 0.0f) {
-        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
-        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
-        llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
-        llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    if (params.temp > 0.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.top_p, params.min_keep));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.seed));
+        llama_sampler_chain_add(smpl, llama_sampler_init_min_p (params.min_p, params.min_keep));
     } else {
         llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
     }
@@ -615,7 +628,7 @@ int main(int argc, char ** argv) {
                 }
 
                 // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
-                text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
+                text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9åäöÅÄÖ\\.,\\?!\\s\\:\\'\\-]"), "");
 
                 // take first line
                 text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));