]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
authorGeorgi Gerganov <redacted>
Sat, 31 Dec 2022 07:55:33 +0000 (09:55 +0200)
committerGeorgi Gerganov <redacted>
Sat, 31 Dec 2022 07:57:19 +0000 (09:57 +0200)
closes #344

whisper.cpp
whisper.h

index d23e97feb8a6629fea2c86d738d362b42918191e..84c24900791a6427a581b86bd1741991ae230394 100644 (file)
@@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
     return ctx->model.hparams.n_text_ctx;
 }
 
+int whisper_n_audio_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_ctx;
+}
+
 int whisper_is_multilingual(struct whisper_context * ctx) {
     return ctx->vocab.is_multilingual() ? 1 : 0;
 }
@@ -2822,7 +2826,11 @@ int whisper_full(
         std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
-    // overwrite audio_ctx
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -4;
+    }
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
index 92c14da0edd3d1db2faa7854381a7ef0c4fee44d..e36b761ff6f9df3cfe30ab3ea30baf0362bb3ee8 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -177,6 +177,7 @@ extern "C" {
     WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
     WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
 
     // The probabilities for the next token