]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : improve handling of prompts (whisper/1981)
authorGeorgi Gerganov <redacted>
Mon, 25 Mar 2024 12:48:19 +0000 (14:48 +0200)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:35:37 +0000 (13:35 +0200)
* whisper : improve handling of prompts

* whisper : add whisper_token_count helper

examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index caa800b3df8ba45445809a2f7873099866e509ca..415c3b33de724836cf20f5d1f9cc00c4ab27c301 100644 (file)
@@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -nt,       --no-timestamps     [%-7s] do not print timestamps\n",                        params.no_timestamps ? "true" : "false");
     fprintf(stderr, "  -l LANG,   --language LANG     [%-7s] spoken language ('auto' for auto-detect)\n",       params.language.c_str());
     fprintf(stderr, "  -dl,       --detect-language   [%-7s] exit after automatically detecting language\n",    params.detect_language ? "true" : "false");
-    fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt\n",                                 params.prompt.c_str());
+    fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt (max n_text_ctx/2 tokens)\n",       params.prompt.c_str());
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
index 8cb7e0ddf747c620e0e30083d1ca019583c72e88..d50c788b3c6a2dc9e696414cee56ecfd9aaa36b8 100644 (file)
@@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
 
     if (n_max_tokens < (int) res.size()) {
         WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
-        return -1;
+        return -(int) res.size();
     }
 
     for (int i = 0; i < (int) res.size(); i++) {
@@ -3731,6 +3731,10 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     return res.size();
 }
 
+int whisper_token_count(struct whisper_context * ctx, const char * text) {
+    return -whisper_tokenize(ctx, text, NULL, 0);
+}
+
 int whisper_lang_max_id() {
     auto max_id = 0;
     for (const auto & kv : g_lang) {
@@ -5313,7 +5317,12 @@ int whisper_full_with_state(
         // initial prompt
         if (!params.prompt_tokens && params.initial_prompt) {
             prompt_tokens.resize(1024);
-            prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
+            int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+            if (n_needed < 0) {
+                prompt_tokens.resize(-n_needed);
+                n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+            }
+            prompt_tokens.resize(n_needed);
             params.prompt_tokens   = prompt_tokens.data();
             params.prompt_n_tokens = prompt_tokens.size();
         }
index 2754337f14307c577bcc43a3cd0d8b489083a37a..bd8d8df828a805fca8445b20edca438da5dc1de2 100644 (file)
@@ -337,7 +337,7 @@ extern "C" {
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
-    // Returns -1 on failure
+    // Returns a negative number on failure - the number of tokens that would have been returned
     // TODO: not sure if correct
     WHISPER_API int whisper_tokenize(
             struct whisper_context * ctx,
@@ -345,6 +345,10 @@ extern "C" {
                      whisper_token * tokens,
                                int   n_max_tokens);
 
+    // Return the number of tokens in the provided text
+    // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
+    int whisper_token_count(struct whisper_context * ctx, const char * text);
+
     // Largest language id (i.e. number of available languages - 1)
     WHISPER_API int whisper_lang_max_id();
 
@@ -503,6 +507,8 @@ extern "C" {
 
         // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
+        // use whisper_tokenize() to convert text to tokens
+        // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
         const char * initial_prompt;
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;