]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : add n_gpu_layers parameter (#1475)
authorrlapray <redacted>
Mon, 13 Nov 2023 08:04:16 +0000 (09:04 +0100)
committerGitHub <redacted>
Mon, 13 Nov 2023 08:04:16 +0000 (10:04 +0200)
examples/talk-llama/talk-llama.cpp

index af971cabdd4ef2ee25656ccc89c7fee575ab3be5..0167b833e447924931e5c2de83de216a0539783f 100644 (file)
@@ -53,6 +53,7 @@ struct whisper_params {
     int32_t capture_id = -1;
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
+    int32_t n_gpu_layers = 0;
 
     float vad_thold  = 0.6f;
     float freq_thold = 100.0f;
@@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-c"   || arg == "--capture")        { params.capture_id     = std::stoi(argv[++i]); }
         else if (arg == "-mt"  || arg == "--max-tokens")     { params.max_tokens     = std::stoi(argv[++i]); }
         else if (arg == "-ac"  || arg == "--audio-ctx")      { params.audio_ctx      = std::stoi(argv[++i]); }
+        else if (arg == "-ngl" || arg == "--n-gpu-layers")   { params.n_gpu_layers   = std::stoi(argv[++i]); }
         else if (arg == "-vth" || arg == "--vad-thold")      { params.vad_thold      = std::stof(argv[++i]); }
         else if (arg == "-fth" || arg == "--freq-thold")     { params.freq_thold     = std::stof(argv[++i]); }
         else if (arg == "-su"  || arg == "--speed-up")       { params.speed_up       = true; }
@@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -c ID,    --capture ID     [%-7d] capture device ID\n",                           params.capture_id);
     fprintf(stderr, "  -mt N,    --max-tokens N   [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
     fprintf(stderr, "  -ac N,    --audio-ctx N    [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
+    fprintf(stderr, "  -ngl N,   --n-gpu-layers N [%-7s] number of layers to store in VRAM\n",           params.n_gpu_layers);
     fprintf(stderr, "  -vth N,   --vad-thold N    [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
     fprintf(stderr, "  -fth N,   --freq-thold N   [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
     fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",     params.speed_up ? "true" : "false");
@@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
     auto lmparams = llama_model_default_params();
     if (!params.use_gpu) {
         lmparams.n_gpu_layers = 0;
+    } else {
+        lmparams.n_gpu_layers = params.n_gpu_layers;
     }
 
     struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);