]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : add no-speech threshold parameter and functionality (#2654)
authorSacha Arbonel <redacted>
Sat, 21 Dec 2024 15:00:08 +0000 (16:00 +0100)
committerGitHub <redacted>
Sat, 21 Dec 2024 15:00:08 +0000 (17:00 +0200)
examples/server/server.cpp
include/whisper.h
src/whisper.cpp

index 0608bb6bd1a50fb275941e9b990749574085edc4..a2ef726fe662091198091f163120417f94b92f9c 100644 (file)
@@ -61,6 +61,7 @@ struct whisper_params {
     float logprob_thold   = -1.00f;
     float temperature     =  0.00f;
     float temperature_inc =  0.20f;
+    float no_speech_thold = 0.6f;
 
     bool debug_mode      = false;
     bool translate       = false;
@@ -137,6 +138,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  --inference-path PATH,         [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
     fprintf(stderr, "  --convert,                     [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
     fprintf(stderr, "  -sns,      --suppress-nst      [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
+    fprintf(stderr, "  -nth N,    --no-speech-thold N [%-7.2f] no speech threshold\n",   params.no_speech_thold);
     fprintf(stderr, "\n");
 }
 
@@ -182,6 +184,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         else if (arg == "-fa"   || arg == "--flash-attn")      { params.flash_attn      = true; }
         else if (arg == "-sns"  || arg == "--suppress-nst")    { params.suppress_nst    = true; }
+        else if (arg == "-nth"  || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
+
         // server params
         else if (                  arg == "--port")            { sparams.port        = std::stoi(argv[++i]); }
         else if (                  arg == "--host")            { sparams.hostname    = argv[++i]; }
@@ -790,6 +794,7 @@ int main(int argc, char ** argv) {
             wparams.beam_search.beam_size = params.beam_size;
 
             wparams.temperature      = params.temperature;
+            wparams.no_speech_thold = params.no_speech_thold;
             wparams.temperature_inc  = params.temperature_inc;
             wparams.entropy_thold    = params.entropy_thold;
             wparams.logprob_thold    = params.logprob_thold;
@@ -942,7 +947,7 @@ int main(int argc, char ** argv) {
 
                 // TODO compression_ratio and no_speech_prob are not implemented yet
                 // segment["compression_ratio"] = 0;
-                // segment["no_speech_prob"] = 0;
+                segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
 
                 jres["segments"].push_back(segment);
             }
index 6e0db505ee18f11b05a4b095f0c41b28f963e6e4..03ce110da70664322cbb64bf799498db1f648d14 100644 (file)
@@ -665,6 +665,8 @@ extern "C" {
 
     WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
+    // Get the no_speech probability for the specified segment
+    WHISPER_API float whisper_full_get_segment_no_speech_prob           (struct whisper_context * ctx, int i_segment);
 #ifdef __cplusplus
 }
 #endif
index ff200ea0f8d6b15f4c42a2c99eedc824f38390a6..5a9f3df8ede8d7f97eb88c7f6761ea81953e149a 100644 (file)
@@ -428,6 +428,7 @@ struct whisper_segment {
     int64_t t1;
 
     std::string text;
+    float no_speech_prob;
 
     std::vector<whisper_token_data> tokens;
 
@@ -6147,7 +6148,7 @@ int whisper_full_with_state(
 
                             //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
 
-                            result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
+                            result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
                             for (int j = i0; j <= i; j++) {
                                 result_all.back().tokens.push_back(tokens_cur[j]);
                             }
@@ -6192,7 +6193,7 @@ int whisper_full_with_state(
                         }
                     }
 
-                    result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
+                    result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
                     for (int j = i0; j < (int) tokens_cur.size(); j++) {
                         result_all.back().tokens.push_back(tokens_cur[j]);
                     }
@@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
     return ctx->state->result_all[i_segment].tokens[i_token].p;
 }
 
+float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
+    return ctx->state->result_all[i_segment].no_speech_prob;
+}
+
 // =================================================================================================
 
 //