]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : sync with whisper.cpp
authorGeorgi Gerganov <redacted>
Sat, 8 Oct 2022 15:15:22 +0000 (18:15 +0300)
committerGeorgi Gerganov <redacted>
Sat, 8 Oct 2022 15:15:22 +0000 (18:15 +0300)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/ggml.c

index 6d1c55dace01eb6966d959b804e0738df820ea55..5362d4a21b6b06cdf1a5b08d3376d8f7cbfc2917 100644 (file)
@@ -5,6 +5,7 @@
 #define DR_WAV_IMPLEMENTATION
 #include "dr_wav.h"
 
+#include <fstream>
 #include <cstdio>
 #include <string>
 #include <thread>
@@ -28,15 +29,20 @@ std::string to_timestamp(int64_t t) {
 struct whisper_params {
     int32_t seed      = -1; // RNG seed, not used currently
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t offset_ms = 0;
 
     bool verbose              = false;
     bool translate            = false;
+    bool output_txt           = false;
+    bool output_vtt           = false;
+    bool output_srt           = false;
     bool print_special_tokens = false;
     bool no_timestamps        = false;
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
-    std::string fname_inp = "samples/jfk.wav";
+
+    std::vector<std::string> fname_inp = {};
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -45,10 +51,17 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
 
+        if (arg[0] != '-') {
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
         if (arg == "-s" || arg == "--seed") {
             params.seed = std::stoi(argv[++i]);
         } else if (arg == "-t" || arg == "--threads") {
             params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-o" || arg == "--offset") {
+            params.offset_ms = std::stoi(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -60,6 +73,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
                 whisper_print_usage(argc, argv, params);
                 exit(0);
             }
+        } else if (arg == "-otxt" || arg == "--output-txt") {
+            params.output_txt = true;
+        } else if (arg == "-ovtt" || arg == "--output-vtt") {
+            params.output_vtt = true;
+        } else if (arg == "-osrt" || arg == "--output-srt") {
+            params.output_srt = true;
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
         } else if (arg == "-nt" || arg == "--no_timestamps") {
@@ -67,7 +86,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
-            params.fname_inp = argv[++i];
+            params.fname_inp.push_back(argv[++i]);
         } else if (arg == "-h" || arg == "--help") {
             whisper_print_usage(argc, argv, params);
             exit(0);
@@ -83,19 +102,23 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
     fprintf(stderr, "\n");
-    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h,       --help           show this help message and exit\n");
     fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -o N,     --offset N       offset in milliseconds (default: %d)\n", params.offset_ms);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
+    fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
+    fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
+    fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path\n");
     fprintf(stderr, "\n");
 }
 
@@ -110,106 +133,189 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
+    if (params.fname_inp.empty()) {
+        fprintf(stderr, "error: no input files specified\n");
+        whisper_print_usage(argc, argv, params);
+        return 2;
+    }
+
     // whisper init
 
     struct whisper_context * ctx = whisper_init(params.model.c_str());
 
-    // WAV input
-    std::vector<float> pcmf32;
-    {
-        drwav wav;
-        if (!drwav_init_file(&wav, params.fname_inp.c_str(), NULL)) {
-            fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str());
-            whisper_print_usage(argc, argv, {});
-            return 2;
-        }
+    for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+        const auto fname_inp = params.fname_inp[f];
+
+        // WAV input
+        std::vector<float> pcmf32;
+        {
+            drwav wav;
+            if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
+                fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
+                whisper_print_usage(argc, argv, {});
+                return 3;
+            }
 
-        if (wav.channels != 1 && wav.channels != 2) {
-            fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
-            return 3;
-        }
+            if (wav.channels != 1 && wav.channels != 2) {
+                fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
+                return 4;
+            }
 
-        if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
-            fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
-            return 4;
-        }
+            if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
+                fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
+                return 5;
+            }
 
-        if (wav.bitsPerSample != 16) {
-            fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str());
-            return 5;
-        }
+            if (wav.bitsPerSample != 16) {
+                fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
+                return 6;
+            }
 
-        int n = wav.totalPCMFrameCount;
+            int n = wav.totalPCMFrameCount;
 
-        std::vector<int16_t> pcm16;
-        pcm16.resize(n*wav.channels);
-        drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
-        drwav_uninit(&wav);
+            std::vector<int16_t> pcm16;
+            pcm16.resize(n*wav.channels);
+            drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+            drwav_uninit(&wav);
 
-        // convert to mono, float
-        pcmf32.resize(n);
-        if (wav.channels == 1) {
-            for (int i = 0; i < n; i++) {
-                pcmf32[i] = float(pcm16[i])/32768.0f;
-            }
-        } else {
-            for (int i = 0; i < n; i++) {
-                pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+            // convert to mono, float
+            pcmf32.resize(n);
+            if (wav.channels == 1) {
+                for (int i = 0; i < n; i++) {
+                    pcmf32[i] = float(pcm16[i])/32768.0f;
+                }
+            } else {
+                for (int i = 0; i < n; i++) {
+                    pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+                }
             }
         }
-    }
 
-    // print some info about the processing
-    {
-        printf("\n");
-        if (!whisper_is_multilingual(ctx)) {
-            if (params.language != "en" || params.translate) {
-                params.language = "en";
-                params.translate = false;
-                printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+        // print some info about the processing
+        {
+            fprintf(stderr, "\n");
+            if (!whisper_is_multilingual(ctx)) {
+                if (params.language != "en" || params.translate) {
+                    params.language = "en";
+                    params.translate = false;
+                    fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+                }
             }
-        }
-        printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
-                __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
-                params.language.c_str(),
-                params.translate ? "translate" : "transcribe",
-                params.no_timestamps ? 0 : 1);
-        printf("\n");
-    }
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
+                    params.language.c_str(),
+                    params.translate ? "translate" : "transcribe",
+                    params.no_timestamps ? 0 : 1);
 
-    // run the inference
-    {
-        whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
-
-        wparams.print_realtime       = true;
-        wparams.print_progress       = false;
-        wparams.print_timestamps     = !params.no_timestamps;
-        wparams.print_special_tokens = params.print_special_tokens;
-        wparams.translate            = params.translate;
-        wparams.language             = params.language.c_str();
-        wparams.n_threads            = params.n_threads;
-
-        if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
-            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-            return 6;
+            fprintf(stderr, "\n");
         }
 
-        // print result;
-        if (!wparams.print_realtime) {
+
+        // run the inference
+        {
+            whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
+
+            wparams.print_realtime       = true;
+            wparams.print_progress       = false;
+            wparams.print_timestamps     = !params.no_timestamps;
+            wparams.print_special_tokens = params.print_special_tokens;
+            wparams.translate            = params.translate;
+            wparams.language             = params.language.c_str();
+            wparams.n_threads            = params.n_threads;
+            wparams.offset_ms            = params.offset_ms;
+
+            if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return 7;
+            }
+
+            // print result
+            if (!wparams.print_realtime) {
+                printf("\n");
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+
+                    if (params.no_timestamps) {
+                        printf("%s", text);
+                        fflush(stdout);
+                    } else {
+                        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                        printf("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                    }
+                }
+            }
+
             printf("\n");
 
-            const int n_segments = whisper_full_n_segments(ctx);
-            for (int i = 0; i < n_segments; ++i) {
-                const char * text = whisper_full_get_segment_text(ctx, i);
+            // output to text file
+            if (params.output_txt) {
+
+                const auto fname_txt = fname_inp + ".txt";
+                std::ofstream fout_txt(fname_txt);
+                if (!fout_txt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str());
+                    return 8;
+                }
+
+                fprintf(stderr, "%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str());
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+                    fout_txt << text;
+                }
+            }
+
+            // output to VTT file
+            if (params.output_vtt) {
+
+                const auto fname_vtt = fname_inp + ".vtt";
+                std::ofstream fout_vtt(fname_vtt);
+                if (!fout_vtt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str());
+                    return 9;
+                }
+
+                fprintf(stderr, "%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str());
+
+                fout_vtt << "WEBVTT\n\n";
+
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
+                    const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+                    const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+                    fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+                    fout_vtt << text << "\n\n";
+                }
+            }
+
+            // output to SRT file
+            if (params.output_srt) {
+
+                const auto fname_srt = fname_inp + ".srt";
+                std::ofstream fout_srt(fname_srt);
+                if (!fout_srt.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_srt.c_str());
+                    return 10;
+                }
+
+                fprintf(stderr, "%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str());
 
-                if (params.no_timestamps) {
-                    printf ("%s", text);
-                    fflush(stdout);
-                } else {
+                const int n_segments = whisper_full_n_segments(ctx);
+                for (int i = 0; i < n_segments; ++i) {
+                    const char * text = whisper_full_get_segment_text(ctx, i);
                     const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                     const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-                    printf ("[%s --> %s]  %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+                    fout_srt << i + 1 << "\n";
+                    fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+                    fout_srt << text << "\n\n";
                 }
             }
         }
index 46a4caa03238acd125e1d5cc51e09f7109bf24d3..81da46944f7d383560f23b69ddb837d29a4c9247 100644 (file)
@@ -405,6 +405,8 @@ struct whisper_context {
 
     std::vector<whisper_result>  result_cur;
     std::vector<whisper_segment> result_all;
+
+    std::vector<whisper_token> prompt_past;
 };
 
 // load the model from a ggml file
@@ -419,7 +421,7 @@ struct whisper_context {
 // see the convert-pt-to-ggml.py script for details
 //
 bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
-    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+    fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
@@ -478,18 +480,18 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
             model.type = e_model::MODEL_LARGE;
         }
 
-        printf("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        printf("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        printf("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        printf("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        printf("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        printf("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        printf("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        printf("%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        printf("%s: f16           = %d\n", __func__, hparams.f16);
-        printf("%s: type          = %d\n", __func__, model.type);
+        fprintf(stderr, "%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        fprintf(stderr, "%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        fprintf(stderr, "%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        fprintf(stderr, "%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        fprintf(stderr, "%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        fprintf(stderr, "%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        fprintf(stderr, "%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
+        fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
         wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type));
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
@@ -501,7 +503,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
                    wctx.buf_compute.size() +
                    wctx.buf_compute_layer.size();
 
-        printf("%s: mem_required  = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: mem_required  = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
     }
 
     // load mel filters
@@ -551,7 +553,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -696,7 +698,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
-        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+        fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
     // create the ggml context
@@ -943,11 +945,12 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
             ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
             ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
+        fprintf(stderr, "%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
     }
 
     // load weights
     {
+        int n_loaded = 0;
         size_t total_size = 0;
 
         while (true) {
@@ -1002,9 +1005,17 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
             //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
+            n_loaded++;
         }
 
-        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+        fprintf(stderr, "%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+
+        if (n_loaded == 0) {
+            fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+        } else if (n_loaded != (int) model.tensors.size()) {
+            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), n_loaded);
+            return false;
+        }
     }
 
     fin.close();
@@ -1020,8 +1031,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 //   - model:      the model
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
-//   - mel_inp:    input mel spectrogram
-//   - features:   output encoded features
 //
 bool whisper_encode(
               whisper_context & wctx,
@@ -1405,10 +1414,9 @@ bool whisper_encode(
 //
 //   - model:      the model
 //   - n_threads:  number of threads to use
-//   - n_past:     prompt length
-//   - prompt:     text prompt
-//   - logits_out: output logits
-//   - probs_out:  output probabilities
+//   - tokens:     text prompt
+//   - n_tokens:   number of tokens in the prompt
+//   - n_past:     number of past tokens to prefix the prompt with
 //
 bool whisper_decode(
               whisper_context & wctx,
@@ -1773,8 +1781,6 @@ bool whisper_decode(
 }
 
 // the most basic sampling scheme - select the top token
-// TODO: beam search
-// TODO: temperature
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs, bool need_timestamp) {
@@ -2236,13 +2242,13 @@ whisper_token whisper_token_transcribe() {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    printf("\n\n");
-    printf("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
-    printf("%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    printf("%s:   sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);
-    printf("%s:   encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer);
-    printf("%s:   decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
-    printf("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
+    fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
+    fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);
+    fprintf(stderr, "%s:   encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer);
+    fprintf(stderr, "%s:   decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
+    fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 ////////////////////////////////////////////////////////////////////////////
@@ -2256,8 +2262,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
                 result = (struct whisper_full_params) {
                     .strategy  = WHISPER_DECODE_GREEDY,
                     .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    .offset_ms = 0,
 
                     .translate            = false,
+                    .no_context           = false,
                     .print_special_tokens = false,
                     .print_progress       = true,
                     .print_realtime       = false,
@@ -2275,8 +2283,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
                 result = (struct whisper_full_params) {
                     .strategy  = WHISPER_DECODE_GREEDY,
                     .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
+                    .offset_ms = 0,
 
                     .translate            = false,
+                    .no_context           = false,
                     .print_special_tokens = false,
                     .print_progress       = true,
                     .print_realtime       = false,
@@ -2295,6 +2305,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
 
     return result;
 }
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -2307,7 +2318,10 @@ int whisper_full(
     }
 
     // the accumulated text context so far
-    std::vector<whisper_token> prompt_past = { };
+    auto & prompt_past = ctx->prompt_past;
+    if (params.no_context) {
+        prompt_past.clear();
+    }
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
@@ -2329,13 +2343,13 @@ int whisper_full(
     int progress_step = 5;
 
     // main loop
-    int seek = 0;
+    int seek = params.offset_ms/10;
     while (true) {
         int progress_cur = (100*seek)/whisper_n_len(ctx);
         while (progress_cur >= progress_prev + progress_step) {
             progress_prev += progress_step;
             if (params.print_progress) {
-                printf("%s: progress = %3d%%\n", __func__, progress_prev);
+                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
             }
         }
 
@@ -2463,7 +2477,7 @@ int whisper_full(
                         result_all.push_back({ t0, t1, text });
                     }
                     text = "";
-                    while (result_cur[i].id > whisper_token_beg(ctx) && i < (int) result_cur.size()) {
+                    while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) {
                         i++;
                     }
                     i--;
index 2df5bdfb763378b608ff6086ed8187c5d1eb3b9b..f462370a33015e4e883f1a4517b6386b751358e9 100644 (file)
@@ -31,33 +31,81 @@ extern "C" {
     // C interface
     //
 
-    // TODO: documentation will come soon
+    //
+    // Basic usage:
+    //
+    //     #include "whisper.h"
+    //
+    //     ...
+    //
+    //     struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
+    //
+    //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+    //         fprintf(stderr, "failed to process audio\n");
+    //         return 7;
+    //     }
+    //
+    //     const int n_segments = whisper_full_n_segments(ctx);
+    //     for (int i = 0; i < n_segments; ++i) {
+    //         const char * text = whisper_full_get_segment_text(ctx, i);
+    //         printf("%s", text);
+    //     }
+    //
+    //     whisper_free(ctx);
+    //
+    //     ...
+    //
+    // This is a demonstration of the most straightforward usage of the library.
+    // "pcmf32" contains the RAW audio data in 32-bit floating point format.
+    //
+    // The interface also allows for more fine-grained control over the computation, but it requires a deeper
+    // understanding of how the model works.
+    //
 
     struct whisper_context;
 
     typedef int whisper_token;
 
+    // Allocates all memory needed for the model and loads the model from the given file.
+    // Returns NULL on failure.
     WHISPER_API struct whisper_context * whisper_init(const char * path_model);
+
+    // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);
 
+    // Convert RAW PCM audio to log mel spectrogram.
+    // The resulting spectrogram is stored inside the provided whisper context.
+    // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
             const float * samples,
             int n_samples,
             int n_threads);
 
+    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+    // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
+    // Returns 0 on success
     WHISPER_API int whisper_set_mel(
             struct whisper_context * ctx,
             const float * data,
             int n_len,
             int n_mel);
 
+    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+    // offset can be used to specify the offset of the first frame in the spectrogram.
+    // Returns 0 on success
     WHISPER_API int whisper_encode(
             struct whisper_context * ctx,
             int offset,
             int n_threads);
 
+    // Run the Whisper decoder to obtain the logits and probabilities for the next token.
+    // Make sure to call whisper_encode() first.
+    // tokens + n_tokens is the provided context for the decoder.
+    // n_past is the number of tokens to use from previous decoder calls.
+    // Returns 0 on success
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
             const whisper_token * tokens,
@@ -65,20 +113,29 @@ extern "C" {
             int n_past,
             int n_threads);
 
+    // Token sampling methods.
+    // These are provided for convenience and can be used after each call to whisper_decode().
+    // You can also implement your own sampling method using the whisper_get_probs() function.
+    // whisper_sample_best() returns the token with the highest probability
+    // whisper_sample_timestamp() returns the most probable timestamp token
     WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
 
-    // return the id of the specified language, returns -1 if not found
+    // Return the id of the specified language, returns -1 if not found
     WHISPER_API int whisper_lang_id(const char * lang);
 
-    WHISPER_API int     whisper_n_len          (struct whisper_context * ctx); // mel length
-    WHISPER_API int     whisper_n_vocab        (struct whisper_context * ctx);
-    WHISPER_API int     whisper_n_text_ctx     (struct whisper_context * ctx);
-    WHISPER_API int     whisper_is_multilingual(struct whisper_context * ctx);
-    WHISPER_API float * whisper_get_probs      (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
+
+    // The probabilities for the next token
+    WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
 
+    // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
 
+    // Special tokens
     WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
@@ -86,24 +143,29 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
 
+    // Task tokens
     WHISPER_API whisper_token whisper_token_translate ();
     WHISPER_API whisper_token whisper_token_transcribe();
 
+    // Performance information
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
 
     ////////////////////////////////////////////////////////////////////////////
 
+    // Available decoding strategies
     enum whisper_decode_strategy {
-        WHISPER_DECODE_GREEDY,
-        WHISPER_DECODE_BEAM_SEARCH,
+        WHISPER_DECODE_GREEDY,      // Always select the most probable token
+        WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet!
     };
 
     struct whisper_full_params {
         enum whisper_decode_strategy strategy;
 
         int n_threads;
+        int offset_ms;
 
         bool translate;
+        bool no_context;
         bool print_special_tokens;
         bool print_progress;
         bool print_realtime;
@@ -126,18 +188,23 @@ extern "C" {
 
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy);
 
-    // full whisper run - encode + decode
+    // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Uses the specified decoding strategy to obtain the text.
     WHISPER_API int whisper_full(
             struct whisper_context * ctx,
             struct whisper_full_params params,
             const float * samples,
             int n_samples);
 
+    // Number of generated text segments.
+    // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
 
+    // Get the start and end time of the specified segment.
     WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
     WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
 
+    // Get the text of the specified segment.
     WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
 
 #ifdef __cplusplus
index 58944893c519f7b376c168287b66aac6644a0ce2..a87e8dbc9f3b726d14b1a44393819a9e1f88bc3f 100644 (file)
@@ -181,9 +181,9 @@ int64_t ggml_cycles_per_ms(void) {
 //
 
 #if defined(__cpp_lib_hardware_interference_size)
-       const size_t CACHE_LINE_SIZE = hardware_destructive_interference_size;
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
 #else
-       const size_t CACHE_LINE_SIZE = 64;
+#define CACHE_LINE_SIZE 64
 #endif
 
 const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);