]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
examples : small code cleanups (#322)
authorAndy Maloney <redacted>
Fri, 23 Dec 2022 18:18:51 +0000 (13:18 -0500)
committerGitHub <redacted>
Fri, 23 Dec 2022 18:18:51 +0000 (20:18 +0200)
- remove unnecessary initialization of string to ""
- use empty() instead of checking size()
- use emplace_back instead of push_back
- use nullptr instead of NULL
- remove unnecessary call to .data() on string
- use character overload of find_first_of() instead of passing a string

examples/command/command.cpp
examples/main/main.cpp
examples/stream/stream.cpp
examples/talk/gpt-2.cpp
examples/talk/talk.cpp

index 0bee82ffa650e64e152acb4224be5588373a4d4a..0ee33067f261757bf55660db253dc87b7ea88e37 100644 (file)
@@ -41,8 +41,8 @@ struct whisper_params {
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
-    std::string fname_out = "";
-    std::string commands  = "";
+    std::string fname_out;
+    std::string commands;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -576,10 +576,10 @@ int main(int argc, char ** argv) {
     std::vector<std::string> allowed_commands;
     std::vector<std::vector<whisper_token>> allowed_tokens;
 
-    std::string k_prompt = "";
+    std::string k_prompt;
     std::vector<whisper_token> k_tokens;
 
-    if (params.commands != "") {
+    if (!params.commands.empty()) {
         fprintf(stderr, "\n");
         fprintf(stderr, "%s: guided mode\n", __func__);
 
@@ -808,7 +808,7 @@ int main(int argc, char ** argv) {
 
                 double psum = 0.0;
                 for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-                    probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
+                    probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
                     for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
                         probs_id.back().first += probs[allowed_tokens[i][j]];
                     }
index b4252d45482a275a1d1cbae150743f003e3dd020..6e991b79392f6f1d3b46b8ef0206ea4f88a37530 100644 (file)
@@ -75,7 +75,7 @@ struct whisper_params {
     bool no_timestamps  = false;
 
     std::string language = "en";
-    std::string prompt   = "";
+    std::string prompt;
     std::string model    = "models/ggml-base.en.bin";
 
     std::vector<std::string> fname_inp = {};
@@ -118,7 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-l"    || arg == "--language")       { params.language       = argv[++i]; }
         else if (                  arg == "--prompt")         { params.prompt         = argv[++i]; }
         else if (arg == "-m"    || arg == "--model")          { params.model          = argv[++i]; }
-        else if (arg == "-f"    || arg == "--file")           { params.fname_inp.push_back(argv[++i]); }
+        else if (arg == "-f"    || arg == "--file")           { params.fname_inp.emplace_back(argv[++i]); }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -206,7 +206,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
             const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
             const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
-            std::string speaker = "";
+            std::string speaker;
 
             if (params.diarize && pcmf32s.size() == 2) {
                 const int64_t n_samples = pcmf32s[0].size();
@@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
     // initial prompt
     std::vector<whisper_token> prompt_tokens;
 
-    if (params.prompt.size() > 0) {
+    if (!params.prompt.empty()) {
         prompt_tokens.resize(1024);
         prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
 
@@ -505,14 +505,14 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
+                if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
                     fprintf(stderr, "error: failed to open WAV file from stdin\n");
                     return 4;
                 }
 
                 fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
             }
-            else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
+            else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
                 fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
                 return 5;
             }
@@ -617,8 +617,8 @@ int main(int argc, char ** argv) {
 
             wparams.speed_up         = params.speed_up;
 
-            wparams.prompt_tokens    = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
-            wparams.prompt_n_tokens  = prompt_tokens.size() == 0 ? 0       : prompt_tokens.size();
+            wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0       : prompt_tokens.size();
 
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
index 4d4abfd111a182d845f3205436d2b99815e11d2d..1752fff83cd6067e7fe0413235ec6f1351f0905a 100644 (file)
@@ -51,7 +51,7 @@ struct whisper_params {
 
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
-    std::string fname_out = "";
+    std::string fname_out;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
index 57ece9b91989d4402b0f03913f611a74b8f6cc5c..c50157fc8ce7204680a5d3e342cb1314d9982d0f 100644 (file)
@@ -40,7 +40,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     // find the longest tokens that form the words:
     std::vector<gpt_vocab::id> tokens;
     for (const auto & word : words) {
-        if (word.size() == 0) continue;
+        if (word.empty()) continue;
 
         int i = 0;
         int n = word.size();
@@ -86,7 +86,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     logits_id.reserve(n_logits);
 
     for (int i = 0; i < n_logits; i++) {
-        logits_id.push_back(std::make_pair(logits[i], i));
+        logits_id.emplace_back(logits[i], i);
     }
 
     // find the top K tokens
@@ -327,7 +327,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     {
         struct ggml_init_params params;
         params.mem_size   = ctx_size;
-        params.mem_buffer = NULL;
+        params.mem_buffer = nullptr;
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -448,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             std::string name(length, 0);
             fin.read(&name[0], length);
 
-            if (model.tensors.find(name.data()) == model.tensors.end()) {
+            if (model.tensors.find(name) == model.tensors.end()) {
                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
@@ -833,7 +833,7 @@ Me too.
 struct gpt2_context * gpt2_init(const char * path_model) {
     gpt2_context * ctx = new gpt2_context;
 
-    ctx->rng = std::mt19937(time(NULL));
+    ctx->rng = std::mt19937(time(nullptr));
 
     // load the model
     {
@@ -886,7 +886,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
 
     for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
         // predict
-        if (embd.size() > 0) {
+        if (!embd.empty()) {
             if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
                 printf("gpt-2: failed to generate text\n");
                 return "";
index e6fe5c8e27f905379369896f1614f8d9421996a3..ec57a95cd70119e37d1ef86ab0f9633b5da89275 100644 (file)
@@ -39,7 +39,7 @@ struct whisper_params {
     std::string model_wsp = "models/ggml-base.en.bin";
     std::string model_gpt = "models/ggml-gpt-2-117M.bin";
     std::string speak     = "./examples/talk/speak.sh";
-    std::string fname_out = "";
+    std::string fname_out;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
 
                 audio.get(params.voice_ms, pcmf32_cur);
 
-                std::string text_heard = "";
+                std::string text_heard;
 
                 if (!force_speak) {
                     text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
@@ -610,7 +610,7 @@ int main(int argc, char ** argv) {
                 text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
 
                 // take first line
-                text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
+                text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
 
                 // remove leading and trailing whitespace
                 text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
@@ -640,18 +640,18 @@ int main(int argc, char ** argv) {
 
                     text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
                     text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
-                    text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
+                    text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
 
                     // remove first 2 lines of base prompt
                     if (n_iter > 4) {
                         {
-                            const size_t pos = prompt_base.find_first_of("\n");
+                            const size_t pos = prompt_base.find_first_of('\n');
                             if (pos != std::string::npos) {
                                 prompt_base = prompt_base.substr(pos + 1);
                             }
                         }
                         {
-                            const size_t pos = prompt_base.find_first_of("\n");
+                            const size_t pos = prompt_base.find_first_of('\n');
                             if (pos != std::string::npos) {
                                 prompt_base = prompt_base.substr(pos + 1);
                             }