]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add support for new distilled Whisper models (#1424)
authorGeorgi Gerganov <redacted>
Sun, 5 Nov 2023 17:43:45 +0000 (19:43 +0200)
committerGitHub <redacted>
Sun, 5 Nov 2023 17:43:45 +0000 (19:43 +0200)
* whisper : add support for new distilled Whisper models

* whisper : print log when using distilled models

whisper.cpp

index 17ef4d9e8abbeb1a7b4eb8e5dc2ced88463aeaf6..3e36d362054e4ae7ab3eefe9a50d3c5c380bb007 100644 (file)
@@ -3940,6 +3940,7 @@ static void whisper_process_logits(
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev]       = -INFINITY;
 
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4558,6 +4559,7 @@ int whisper_full_with_state(
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
@@ -4569,6 +4571,17 @@ int whisper_full_with_state(
         }
     }
 
+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+
+        // distilled models require the "no_timestamps" token
+        // TODO: add input parameter (#1229)
+        if (is_distil) {
+            log("%s: using distilled model - forcing no_timestamps\n", __func__);
+            prompt_init.push_back(whisper_token_not(ctx));
+        }
+    }
+
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;