]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : implement "verbose_json" format with token details (#1781)
authorRyan Hitchman <redacted>
Thu, 18 Jan 2024 20:58:42 +0000 (13:58 -0700)
committerGitHub <redacted>
Thu, 18 Jan 2024 20:58:42 +0000 (22:58 +0200)
* examples/server: implement "verbose_json" format with token details.

This is intended to mirror the format of openai's Python
whisper.transcribe() return values.

* server: don't write WAV to a temporary file if not converting

* server: use std::lock_guard instead of manual lock/unlock

examples/common.cpp
examples/common.h
examples/server/server.cpp

index 603c655a184c745c0332351d729c2530afbe5772..8404e00e09e6e36383a2b4234c842ed39bf2b1ab 100644 (file)
@@ -639,6 +639,12 @@ bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector
 
         fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
     }
+    else if (fname.size() > 256 || fname.size() > 40 && fname.substr(0, 4) == "RIFF" && fname.substr(8, 4) == "WAVE") {
+        if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
+            fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
+            return false;
+        }
+    }
     else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
         fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
         return false;
index 54f0b00d0ef41fd4e65c969663536c85bcb7680c..aebeb0cd4f57aa5c3177fe6bdb14c3779e366bff 100644 (file)
@@ -136,6 +136,7 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
 //
 
 // Read WAV audio file and store the PCM data into pcmf32
+// fname can be a buffer of WAV data instead of a filename
 // The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
 // If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
 bool read_wav(
index 8b6e4695259dabfe3fd426fcbb1a24ebf5f2f286..7de31859615caa8de47d8244f74f4f14a799db88 100644 (file)
@@ -18,7 +18,7 @@
 #endif
 
 using namespace httplib;
-using json = nlohmann::json;
+using json = nlohmann::ordered_json;
 
 namespace {
 
@@ -556,7 +556,7 @@ int main(int argc, char ** argv) {
 
     svr.Post(sparams.request_path + "/inference", [&](const Request &req, Response &res){
         // acquire whisper model mutex lock
-        whisper_mutex.lock();
+        std::lock_guard<std::mutex> lock(whisper_mutex);
 
         // first check user requested fields of the request
         if (!req.has_file("file"))
@@ -564,7 +564,6 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "error: no 'file' field in the request\n");
             const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
             res.set_content(error_resp, "application/json");
-            whisper_mutex.unlock();
             return;
         }
         auto audio_file = req.get_file_value("file");
@@ -579,35 +578,42 @@ int main(int argc, char ** argv) {
         std::vector<float> pcmf32;               // mono-channel F32 PCM
         std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
-        // write to temporary file
-        const std::string temp_filename = "whisper_server_temp_file.wav";
-        std::ofstream temp_file{temp_filename, std::ios::binary};
-        temp_file << audio_file.content;
-        temp_file.close();
-
-        // if file is not wav, convert to wav
-
         if (sparams.ffmpeg_converter) {
+            // if file is not wav, convert to wav
+            // write to temporary file
+            const std::string temp_filename = "whisper_server_temp_file.wav";
+            std::ofstream temp_file{temp_filename, std::ios::binary};
+            temp_file << audio_file.content;
+            temp_file.close();
+
             std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
             const bool is_converted = convert_to_wav(temp_filename, error_resp);
             if (!is_converted) {
                 res.set_content(error_resp, "application/json");
-                whisper_mutex.unlock();
                 return;
             }
-        }
 
-        // read wav content into pcmf32
-        if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
-            fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
-            const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
-            res.set_content(error_resp, "application/json");
+            // read wav content into pcmf32
+            if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize))
+            {
+                fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
+                const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
+                res.set_content(error_resp, "application/json");
+                std::remove(temp_filename.c_str());
+                return;
+            }
+            // remove temp file
             std::remove(temp_filename.c_str());
-            whisper_mutex.unlock();
-            return;
+        } else {
+            if (!::read_wav(audio_file.content, pcmf32, pcmf32s, params.diarize))
+            {
+                fprintf(stderr, "error: failed to read WAV file\n");
+                const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
+                res.set_content(error_resp, "application/json");
+                return;
+            }
         }
-        // remove temp file
-        std::remove(temp_filename.c_str());
+
 
         printf("Successfully loaded %s\n", filename.c_str());
 
@@ -681,6 +687,7 @@ int main(int argc, char ** argv) {
             wparams.logprob_thold    = params.logprob_thold;
 
             wparams.no_timestamps    = params.no_timestamps;
+            wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
 
             whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
 
@@ -724,7 +731,6 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 const std::string error_resp = "{\"error\":\"failed to process audio\"}";
                 res.set_content(error_resp, "application/json");
-                whisper_mutex.unlock();
                 return;
             }
         }
@@ -778,6 +784,43 @@ int main(int argc, char ** argv) {
                 ss << speaker << text << "\n\n";
             }
             res.set_content(ss.str(), "text/vtt");
+        } else if (params.response_format == vjson_format) {
+            /* try to match openai/whisper's Python format */
+            std::string results = output_str(ctx, params, pcmf32s);
+            json jres = json{{"text", results}};
+            const int n_segments = whisper_full_n_segments(ctx);
+            for (int i = 0; i < n_segments; ++i)
+            {
+                json segment = json{
+                    {"id", i},
+                    {"text", whisper_full_get_segment_text(ctx, i)},
+                };
+
+                if (!params.no_timestamps) {
+                    segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01;
+                    segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01;
+                }
+
+                const int n_tokens = whisper_full_n_tokens(ctx, i);
+                for (int j = 0; j < n_tokens; ++j) {
+                    whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
+                    if (token.id >= whisper_token_eot(ctx)) {
+                        continue;
+                    }
+
+                    segment["tokens"].push_back(token.id);
+                    json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}};
+                    if (!params.no_timestamps) {
+                        word["start"] = token.t0 * 0.01;
+                        word["end"] = token.t1 * 0.01;
+                    }
+                    word["probability"] = token.p;
+                    segment["words"].push_back(word);
+                }
+                jres["segments"].push_back(segment);
+            }
+            res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
         }
         // TODO add more output formats
         else
@@ -792,18 +835,14 @@ int main(int argc, char ** argv) {
 
         // reset params to thier defaults
         params = default_params;
-
-        // return whisper model mutex lock
-        whisper_mutex.unlock();
     });
     svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
-        whisper_mutex.lock();
+        std::lock_guard<std::mutex> lock(whisper_mutex);
         if (!req.has_file("model"))
         {
             fprintf(stderr, "error: no 'model' field in the request\n");
             const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
             res.set_content(error_resp, "application/json");
-            whisper_mutex.unlock();
             return;
         }
         std::string model = req.get_file_value("model").content;
@@ -812,7 +851,6 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
             const std::string error_resp = "{\"error\":\"model not found!\"}";
             res.set_content(error_resp, "application/json");
-            whisper_mutex.unlock();
             return;
         }
 
@@ -835,7 +873,6 @@ int main(int argc, char ** argv) {
         res.set_content(success, "application/text");
 
         // check if the model is in the file system
-        whisper_mutex.unlock();
     });
 
     svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {