]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
main : fix sampling time + add max_context parameter
authorGeorgi Gerganov <redacted>
Sat, 29 Oct 2022 06:42:14 +0000 (09:42 +0300)
committerGeorgi Gerganov <redacted>
Sat, 29 Oct 2022 16:37:19 +0000 (19:37 +0300)
examples/main/main.cpp
whisper.cpp
whisper.h

index 91e4a375522ee53b2a281141a3bca0b03eaf46ad..b0d576f1364ec99b5d986268e86be9307c443b06 100644 (file)
@@ -42,6 +42,7 @@ struct whisper_params {
     int32_t n_threads   = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t offset_t_ms = 0;
     int32_t offset_n    = 0;
+    int32_t max_context = -1;
 
     bool verbose              = false;
     bool translate            = false;
@@ -77,6 +78,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.offset_t_ms = std::stoi(argv[++i]);
         } else if (arg == "-on" || arg == "--offset-n") {
             params.offset_n = std::stoi(argv[++i]);
+        } else if (arg == "-mc" || arg == "--max-context") {
+            params.max_context = std::stoi(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -127,6 +130,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
+    fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
@@ -380,6 +384,8 @@ int main(int argc, char ** argv) {
             wparams.translate            = params.translate;
             wparams.language             = params.language.c_str();
             wparams.n_threads            = params.n_threads;
+            wparams.n_processors         = 1;
+            wparams.n_max_text_ctx       = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
             wparams.offset_ms            = params.offset_t_ms;
 
             // this callback is called on each new segment
index ee8d994fd0ef2945c5696892160ac4e78af71b7a..168182f36480d2a50e73029d44e7a061d2db1e68 100644 (file)
@@ -211,14 +211,6 @@ struct whisper_vocab {
     }
 };
 
-struct whisper_token_data {
-    whisper_token id;  // token id
-    whisper_token tid; // forced timestamp token id
-
-    float p;  // probability of the token
-    float pt; // probability of the timestamp token
-};
-
 struct whisper_segment {
     int64_t t0;
     int64_t t1;
@@ -2219,7 +2211,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-whisper_token whisper_sample_best(struct whisper_context * ctx) {
+whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     // TODO: simplify
@@ -2227,7 +2219,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 
-    return res.id;
+    return res;
 }
 
 whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
@@ -2330,8 +2322,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy             =*/ WHISPER_SAMPLING_GREEDY,
 
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.offset_ms            =*/ 0,
                     /*.n_processors         =*/ 1,
+                    /*.n_max_text_ctx       =*/ 16384,
+                    /*.offset_ms            =*/ 0,
 
                     /*.translate            =*/ false,
                     /*.no_context           =*/ false,
@@ -2362,8 +2355,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy             =*/ WHISPER_SAMPLING_BEAM_SEARCH,
 
                     /*.n_threads            =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.offset_ms            =*/ 0,
                     /*.n_processors         =*/ 1,
+                    /*.n_max_text_ctx       =*/ 16384,
+                    /*.offset_ms            =*/ 0,
 
                     /*.translate            =*/ false,
                     /*.no_context           =*/ false,
@@ -2470,7 +2464,7 @@ int whisper_full(
 
         // if we have already generated some text, use it as a prompt to condition the next generation
         if (prompt_past.size() > 0) {
-            int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
+            int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
             prompt = { whisper_token_prev(ctx) };
             prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
@@ -2512,7 +2506,7 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+                auto token = whisper_sample_best(ctx);
 
                 if (i == 0) {
                     token.tid = whisper_token_beg(ctx);
index c918e98b3f44d1bb8aee714ea22bed357bf6f1cf..9368e25ca0aed1be4733d28d1df1809f36a9e73d 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -68,6 +68,14 @@ extern "C" {
 
     typedef int whisper_token;
 
+    struct whisper_token_data {
+        whisper_token id;  // token id
+        whisper_token tid; // forced timestamp token id
+
+        float p;  // probability of the token
+        float pt; // probability of the timestamp token
+    };
+
     // Allocates all memory needed for the model and loads the model from the given file.
     // Returns NULL on failure.
     WHISPER_API struct whisper_context * whisper_init(const char * path_model);
@@ -122,7 +130,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
+    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
 
     // Return the id of the specified language, returns -1 if not found
@@ -171,8 +179,9 @@ extern "C" {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
-        int offset_ms;
         int n_processors;
+        int n_max_text_ctx;
+        int offset_ms;
 
         bool translate;
         bool no_context;