]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
server : add more parameters to server api (#1754)
authorGeorge Hindle <redacted>
Fri, 12 Jan 2024 11:42:52 +0000 (11:42 +0000)
committerGitHub <redacted>
Fri, 12 Jan 2024 11:42:52 +0000 (13:42 +0200)
* feat(server): add more parameters to server api

* fix(server): reset params to original parsed values for each request

examples/server/server.cpp

index 6f3ca6be8f4107a3f7cc0839962b4b68ced8364b..8b6e4695259dabfe3fd426fcbb1a24ebf5f2f286 100644 (file)
@@ -397,6 +397,13 @@ std::string output_str(struct whisper_context * ctx, const whisper_params & para
     return result.str();
 }
 
+bool parse_str_to_bool(const std::string & s) {
+    if (s == "true" || s == "1" || s == "yes" || s == "y") {
+        return true;
+    }
+    return false;
+}
+
 void get_req_parameters(const Request & req, whisper_params & params)
 {
     if (req.has_file("offset_t"))
@@ -415,6 +422,62 @@ void get_req_parameters(const Request & req, whisper_params & params)
     {
         params.max_context = std::stoi(req.get_file_value("max_context").content);
     }
+    if (req.has_file("max_len"))
+    {
+        params.max_len = std::stoi(req.get_file_value("max_len").content);
+    }
+    if (req.has_file("best_of"))
+    {
+        params.best_of = std::stoi(req.get_file_value("best_of").content);
+    }
+    if (req.has_file("beam_size"))
+    {
+        params.beam_size = std::stoi(req.get_file_value("beam_size").content);
+    }
+    if (req.has_file("word_thold"))
+    {
+        params.word_thold = std::stof(req.get_file_value("word_thold").content);
+    }
+    if (req.has_file("entropy_thold"))
+    {
+        params.entropy_thold = std::stof(req.get_file_value("entropy_thold").content);
+    }
+    if (req.has_file("logprob_thold"))
+    {
+        params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content);
+    }
+    if (req.has_file("debug_mode"))
+    {
+        params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content);
+    }
+    if (req.has_file("translate"))
+    {
+        params.translate = parse_str_to_bool(req.get_file_value("translate").content);
+    }
+    if (req.has_file("diarize"))
+    {
+        params.diarize = parse_str_to_bool(req.get_file_value("diarize").content);
+    }
+    if (req.has_file("tinydiarize"))
+    {
+        params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content);
+    }
+    if (req.has_file("split_on_word"))
+    {
+        params.split_on_word = parse_str_to_bool(req.get_file_value("split_on_word").content);
+    }
+    if (req.has_file("no_timestamps"))
+    {
+        params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content);
+    }
+    if (req.has_file("language"))
+    {
+        params.language = req.get_file_value("language").content;
+    }
+    if (req.has_file("detect_language"))
+    {
+        params.detect_language = parse_str_to_bool(req.get_file_value("detect_language").content);
+    }
     if (req.has_file("prompt"))
     {
         params.prompt = req.get_file_value("prompt").content;
@@ -482,6 +545,9 @@ int main(int argc, char ** argv) {
 
     std::string const default_content = "<html>hello</html>";
 
+    // store default params so we can reset after each inference request
+    whisper_params default_params = params;
+
     // this is only called if no index.html is found in the public --path
     svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
         res.set_content(default_content, "text/html");
@@ -724,6 +790,9 @@ int main(int argc, char ** argv) {
                             "application/json");
         }
 
+        // reset params to thier defaults
+        params = default_params;
+
         // return whisper model mutex lock
         whisper_mutex.unlock();
     });