fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
- fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
return 2;
}
- if (whisper_lang_id(params.language.c_str()) == -1) {
+ if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size();
- params.mem_buffer = wctx.buf_compute.data();
+ params.mem_buffer = wctx.buf_compute.data();
struct ggml_context * ctx0 = ggml_init(params);
return res.size();
}
+int whisper_lang_max_id() {
+ auto max_id = 0;
+ for (const auto & kv : g_lang) {
+ max_id = std::max(max_id, kv.second.first);
+ }
+
+ return max_id;
+}
+
int whisper_lang_id(const char * lang) {
if (!g_lang.count(lang)) {
+ for (const auto & kv : g_lang) {
+ if (kv.second.second == lang) {
+ return kv.second.first;
+ }
+ }
+
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
return -1;
}
return g_lang.at(lang).first;
}
+const char * whisper_lang_str(int id) {
+ for (const auto & kv : g_lang) {
+ if (kv.second.first == id) {
+ return kv.first.c_str();
+ }
+ }
+
+ fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+ return NULL;
+}
+
+int whisper_lang_auto_detect(
+ struct whisper_context * ctx,
+ int offset_ms,
+ int n_threads,
+ float * lang_probs) {
+ const int seek = offset_ms/10;
+
+ if (seek < 0) {
+ fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+ return -1;
+ }
+
+ if (seek >= ctx->mel.n_len) {
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+ return -2;
+ }
+
+ // run the encoder
+ if (whisper_encode(ctx, seek, n_threads) != 0) {
+ fprintf(stderr, "%s: failed to encode\n", __func__);
+ return -6;
+ }
+
+ const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+ fprintf(stderr, "%s: failed to decode\n", __func__);
+ return -7;
+ }
+
+ std::vector<std::pair<float, int>> probs_id;
+ for (const auto kv : g_lang) {
+ const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+ probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
+ }
+
+ // sort descending
+ {
+ using pair_type = decltype(probs_id)::value_type;
+ std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+ return a.first > b.first;
+ });
+ }
+
+ // softmax
+ {
+ float sum = 0;
+ for (const auto & kv : probs_id) {
+ sum += exp(kv.first);
+ }
+
+ for (auto & kv : probs_id) {
+ kv.first = exp(kv.first) / sum;
+ }
+ }
+
+ {
+ for (int i = 0; i < probs_id.size(); i++) {
+ if (lang_probs) {
+ lang_probs[probs_id[i].second] = probs_id[i].first;
+ }
+
+ //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+ }
+ }
+
+ return probs_id[0].second;
+}
+
int whisper_n_len(struct whisper_context * ctx) {
return ctx->mel.n_len;
}
return ctx->vocab.token_beg;
}
+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+ return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
}
} else {
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
- return -1;
+ return -2;
}
}
+ // auto-detect language if not specified
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+ const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+ if (lang_id < 0) {
+ fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+ return -3;
+ }
+
+ params.language = whisper_lang_str(lang_id);
+
+ fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+ }
+
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
- prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
+ const int lang_id = whisper_lang_id(params.language);
+ prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
} else {
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
- return 7;
+ return -4;
}
int n_past = 0;
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
- return 8;
+ return -5;
}
n_past += prompt.size();
whisper_token * tokens,
int n_max_tokens);
+ // Largest language id (i.e. number of available languages - 1)
+ WHISPER_API int whisper_lang_max_id();
+
// Return the id of the specified language, returns -1 if not found
+ // Examples:
+ // "de" -> 2
+ // "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
+ // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+ WHISPER_API const char * whisper_lang_str(int id);
+
+ // Use mel data at offset_ms to try and auto-detect the spoken language
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+ // Returns the top language id or negative on failure
+ // If not null, fills the lang_probs array with the probabilities of all languages
+ // The array must be whispe_lang_max_id() + 1 in size
+ // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+ WHISPER_API int whisper_lang_auto_detect(
+ struct whisper_context * ctx,
+ int offset_ms,
+ int n_threads,
+ float * lang_probs);
+
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (void);
const whisper_token * prompt_tokens;
int prompt_n_tokens;
+ // for auto-detection, set to nullptr, "" or "auto"
const char * language;
struct {