From: Georgi Gerganov
Date: Sat, 31 Dec 2022 07:55:33 +0000 (+0200)
Subject: whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
X-Git-Tag: upstream/1.7.4~1666
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d97e6005e95f31ff812f72cd2cad3347080d1520;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

whisper : add whisper_n_audio_ctx and check for invalid audio_ctx

closes #344
---

diff --git a/whisper.cpp b/whisper.cpp
index d23e97fe..84c24900 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
     return ctx->model.hparams.n_text_ctx;
 }
 
+int whisper_n_audio_ctx(struct whisper_context * ctx) {
+    return ctx->model.hparams.n_audio_ctx;
+}
+
 int whisper_is_multilingual(struct whisper_context * ctx) {
     return ctx->vocab.is_multilingual() ? 1 : 0;
 }
@@ -2822,7 +2826,11 @@ int whisper_full(
         std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
-    // overwrite audio_ctx
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -4;
+    }
     ctx->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
diff --git a/whisper.h b/whisper.h
index 92c14da0..e36b761f 100644
--- a/whisper.h
+++ b/whisper.h
@@ -177,6 +177,7 @@ extern "C" {
     WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
     WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
     WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
     WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
 
     // The probabilities for the next token