return std::string(buf);
}
+// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos);
}
}
-// a cost-function that is high for text that takes longer to pronounce
-float voice_length(const std::string & text) {
- float res = 0.0f;
-
- for (size_t i = 0; i < text.size(); ++i) {
- if (text[i] == ' ') {
- res += 0.01f;
- } else if (text[i] == ',') {
- res += 2.00f;
- } else if (text[i] == '.') {
- res += 3.00f;
- } else if (text[i] == '!') {
- res += 3.00f;
- } else if (text[i] == '?') {
- res += 3.00f;
- } else if (text[i] >= '0' && text[i] <= '9') {
- res += 3.00f;
- } else {
- res += 1.00f;
- }
- }
-
- return res;
-}
-
// command-line parameters
struct whisper_params {
int32_t seed = -1; // RNG seed, not used currently
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
+ int32_t duration_ms = 0;
int32_t max_context = -1;
+ int32_t max_len = 0;
float word_thold = 0.01f;
params.offset_t_ms = std::stoi(argv[++i]);
} else if (arg == "-on" || arg == "--offset-n") {
params.offset_n = std::stoi(argv[++i]);
+ } else if (arg == "-d" || arg == "--duration") {
+ params.duration_ms = std::stoi(argv[++i]);
} else if (arg == "-mc" || arg == "--max-context") {
params.max_context = std::stoi(argv[++i]);
+ } else if (arg == "-ml" || arg == "--max-len") {
+ params.max_len = std::stoi(argv[++i]);
} else if (arg == "-wt" || arg == "--word-thold") {
params.word_thold = std::stof(argv[++i]);
} else if (arg == "-v" || arg == "--verbose") {
fprintf(stderr, " -p N, --processors N number of processors to use during computation (default: %d)\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
+ fprintf(stderr, " -d N, --duration N duration of audio to process in milliseconds (default: %d)\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
+ fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
- fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n");
+ fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -pc, --print_colors print colors\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, "\n");
}
-void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
- // print the last segment
- const int i = n_segments - 1;
- if (i == 0) {
+ // print the last n_new segments
+ const int s0 = n_segments - n_new;
+ if (s0 == 0) {
printf("\n");
}
- if (params.no_timestamps) {
- if (params.print_colors) {
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special_tokens == false) {
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
- if (id >= whisper_token_eot(ctx)) {
- continue;
+ for (int i = s0; i < n_segments; i++) {
+ if (params.no_timestamps) {
+ if (params.print_colors) {
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ if (params.print_special_tokens == false) {
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+ if (id >= whisper_token_eot(ctx)) {
+ continue;
+ }
}
- }
- const char * text = whisper_full_get_token_text(ctx, i, j);
- const float p = whisper_full_get_token_p (ctx, i, j);
+ const char * text = whisper_full_get_token_text(ctx, i, j);
+ const float p = whisper_full_get_token_p (ctx, i, j);
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+ }
+ } else {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+ printf("%s", text);
}
+ fflush(stdout);
} else {
- const char * text = whisper_full_get_segment_text(ctx, i);
- printf("%s", text);
- }
- fflush(stdout);
- } else {
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
- if (params.print_colors) {
- printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special_tokens == false) {
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
- if (id >= whisper_token_eot(ctx)) {
- continue;
+ if (params.print_colors) {
+ printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ if (params.print_special_tokens == false) {
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+ if (id >= whisper_token_eot(ctx)) {
+ continue;
+ }
}
- }
- const char * text = whisper_full_get_token_text(ctx, i, j);
- const float p = whisper_full_get_token_p (ctx, i, j);
+ const char * text = whisper_full_get_token_text(ctx, i, j);
+ const float p = whisper_full_get_token_p (ctx, i, j);
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
- }
- printf("\n");
- } else {
- const char * text = whisper_full_get_segment_text(ctx, i);
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+ }
+ printf("\n");
+ } else {
+ const char * text = whisper_full_get_segment_text(ctx, i);
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+ printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+ }
}
}
}
return true;
}
-// word-level timestamps (experimental)
-// TODO: probably still has bugs, needs refactoring, etc..
-// TODO: auto threshold
-// TODO: extra pass to detect unused speech and assign to tokens
+// karaoke video generation
+// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
- if (params.output_wts) {
- std::vector<float> pcm_avg(pcmf32.size(), 0);
-
- // average the fabs of the signal
- {
- const int hw = 32;
-
- for (int i = 0; i < pcmf32.size(); i++) {
- float sum = 0;
- for (int j = -hw; j <= hw; j++) {
- if (i + j >= 0 && i + j < pcmf32.size()) {
- sum += fabs(pcmf32[i + j]);
- }
- }
- pcm_avg[i] = sum/(2*hw + 1);
- }
- }
-
- struct token_info {
- int64_t t0 = -1;
- int64_t t1 = -1;
-
- int64_t tt0 = -1;
- int64_t tt1 = -1;
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
+ std::ofstream fout(fname);
- whisper_token id;
- whisper_token tid;
+ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
- float p = 0.0f;
- float pt = 0.0f;
- float ptsum = 0.0f;
+ // TODO: become parameter
+ static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
- std::string text;
- float vlen = 0.0f; // voice length of this token
- };
+ fout << "#!/bin/bash" << "\n";
+ fout << "\n";
- int64_t t_beg = 0;
- int64_t t_last = 0;
+ fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
- whisper_token tid_last = 0;
+ for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
- std::ofstream fout(fname);
+ const int n = whisper_full_n_tokens(ctx, i);
- fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+ std::vector<whisper_token_data> tokens(n);
+ for (int j = 0; j < n; ++j) {
+ tokens[j] = whisper_full_get_token_data(ctx, i, j);
+ }
- fout << "!/bin/bash" << "\n";
- fout << "\n";
+ if (i > 0) {
+ fout << ",";
+ }
- fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
+ // background text
+ fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
- for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-
- const char *text = whisper_full_get_segment_text(ctx, i);
-
- const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
- const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
-
- const int n = whisper_full_n_tokens(ctx, i);
+ for (int j = 0; j < n; ++j) {
+ const auto & token = tokens[j];
- std::vector<token_info> tokens(n);
-
- if (n <= 1) {
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
- for (int j = 0; j < n; ++j) {
- struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
-
- if (j == 0) {
- if (token.id == whisper_token_beg(ctx)) {
- tokens[j ].t0 = t0;
- tokens[j ].t1 = t0;
- tokens[j + 1].t0 = t0;
-
- t_beg = t0;
- t_last = t0;
- tid_last = whisper_token_beg(ctx);
- } else {
- tokens[j ].t0 = t_last;
- }
- }
-
- const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
-
- tokens[j].id = token.id;
- tokens[j].tid = token.tid;
- tokens[j].p = token.p;
- tokens[j].pt = token.pt;
- tokens[j].ptsum = token.ptsum;
-
- tokens[j].text = whisper_token_to_str(ctx, token.id);
- //tokens[j].vlen = tokens[j].pt;
- tokens[j].vlen = voice_length(tokens[j].text);
-
- if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
- if (j > 0) {
- tokens[j - 1].t1 = tt;
- }
- tokens[j].t0 = tt;
- tid_last = token.tid;
- }
- }
-
- tokens[n - 2].t1 = t1;
- tokens[n - 1].t0 = t1;
- tokens[n - 1].t1 = t1;
-
- t_last = t1;
-
- int p0 = 0;
- int p1 = 0;
- while (true) {
- while (p1 < n && tokens[p1].t1 < 0) {
- p1++;
- }
-
- if (p1 >= n) {
- p1--;
- }
-
- if (p1 > p0) {
- double psum = 0.0;
- for (int j = p0; j <= p1; j++) {
- psum += tokens[j].vlen;
- }
-
- //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
-
- const double dt = tokens[p1].t1 - tokens[p0].t0;
+ std::string txt_bg;
+ std::string txt_fg; // highlight token
+ std::string txt_ul; // underline
- for (int j = p0 + 1; j <= p1; j++) {
- const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
- //const double ct = tokens[j - 1].t0 + (dt*(j - p0))/(p1 - p0 + 1);
- //const double ct = tokens[p0].t0 + (dt*(j - p0))/(p1 - p0 + 1);
+ txt_bg = "> ";
+ txt_fg = "> ";
+ txt_ul = "\\ \\ ";
- tokens[j - 1].t1 = ct;
- tokens[j ].t0 = ct;
- }
- }
-
- p1++;
- p0 = p1;
- if (p1 >= n) {
- break;
- }
- }
-
- for (int j = 0; j < n - 1; j++) {
- if (tokens[j].t1 < 0) {
- tokens[j + 1].t0 = tokens[j].t1;
- }
-
- if (j > 0) {
- if (tokens[j - 1].t1 > tokens[j].t0) {
- tokens[j].t0 = tokens[j - 1].t1;
- tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
- }
- }
-
- tokens[j].tt0 = tokens[j].t0;
- tokens[j].tt1 = tokens[j].t1;
- }
-
- // VAD
{
- const int hw = WHISPER_SAMPLE_RATE/8;
+ int ncnt = 0;
+ for (int k = 0; k < n; ++k) {
+ const auto & token2 = tokens[k];
- for (int j = 0; j < n; j++) {
- if (tokens[j].id >= whisper_token_eot(ctx)) {
+ if (tokens[k].id >= whisper_token_eot(ctx)) {
continue;
}
- const int64_t t0 = tokens[j].t0;
- const int64_t t1 = tokens[j].t1;
-
- int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
- int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
+ const std::string txt = whisper_token_to_str(ctx, token2.id);
- const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
- const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
-
- const int n = ss1 - ss0;
-
- float sum = 0.0f;
-
- for (int k = ss0; k < ss1; k++) {
- sum += pcm_avg[k];
- }
+ txt_bg += txt;
- const float thold = 0.5*sum/n;
-
- {
- int k = s0;
- if (pcm_avg[k] > thold && j > 0) {
- while (k > 0 && pcm_avg[k] > thold) {
- k--;
- }
- tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
- if (tokens[j].t0 < tokens[j - 1].t1) {
- tokens[j].t0 = tokens[j - 1].t1;
- } else {
- s0 = k;
- }
- } else {
- while (pcm_avg[k] < thold && k < s1) {
- k++;
- }
- s0 = k;
- tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
+ if (k == j) {
+ for (int l = 0; l < (int) txt.size(); ++l) {
+ txt_fg += txt[l];
+ txt_ul += "_";
}
- }
-
- {
- int k = s1;
- if (pcm_avg[k] > thold) {
- while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
- k++;
- }
- tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
- if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
- tokens[j].t1 = tokens[j + 1].t0;
- } else {
- s1 = k;
- }
- } else {
- while (pcm_avg[k] < thold && k > s0) {
- k--;
- }
- s1 = k;
- tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
+ txt_fg += "|";
+ } else {
+ for (int l = 0; l < (int) txt.size(); ++l) {
+ txt_fg += "\\ ";
+ txt_ul += "\\ ";
}
}
- }
- }
-
- const int t_expand = 0;
- for (int j = 0; j < n; j++) {
- if (j > 0) {
- tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+ ncnt += txt.size();
}
- if (j < n - 1) {
- tokens[j].t1 = tokens[j].t1 + t_expand;
- }
- }
-
- for (int j = 0; j < n; ++j) {
- const auto & token = tokens[j];
- const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
- printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
- tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
-
- if (tokens[j].id >= whisper_token_eot(ctx)) {
- continue;
- }
-
- //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
- //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
+ ::replace_all(txt_bg, "'", "’");
+ ::replace_all(txt_bg, "\"", "\\\"");
+ ::replace_all(txt_fg, "'", "’");
+ ::replace_all(txt_fg, "\"", "\\\"");
}
- static const int line_wrap = 60;
- static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
-
- if (!is_first) {
- fout << ",";
- }
-
- // background text
- fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
-
- is_first = false;
-
- for (int j = 0; j < n; ++j) {
- const auto & token = tokens[j];
-
- if (tokens[j].id >= whisper_token_eot(ctx)) {
- continue;
- }
-
- std::string txt_bg;
- std::string txt_fg; // highlight token
- std::string txt_ul; // underline
-
- txt_bg = "> ";
- txt_fg = "> ";
- txt_ul = "\\ \\ ";
-
- {
- int ncnt = 0;
- for (int k = 0; k < n; ++k) {
- const auto & token2 = tokens[k];
-
- if (tokens[k].id >= whisper_token_eot(ctx)) {
- continue;
- }
-
- const std::string txt = whisper_token_to_str(ctx, token2.id);
-
- txt_bg += txt;
-
- if (k == j) {
- for (int l = 0; l < (int) txt.size(); ++l) {
- txt_fg += txt[l];
- txt_ul += "_";
- }
- txt_fg += "|";
- } else {
- for (int l = 0; l < (int) txt.size(); ++l) {
- txt_fg += "\\ ";
- txt_ul += "\\ ";
- }
- }
-
- ncnt += txt.size();
-
- if (ncnt > line_wrap) {
- if (k < j) {
- txt_bg = "> ";
- txt_fg = "> ";
- txt_ul = "\\ \\ ";
- ncnt = 0;
- } else {
- break;
- }
- }
- }
-
- ::replace_all(txt_bg, "'", "’");
- ::replace_all(txt_bg, "\"", "\\\"");
- ::replace_all(txt_fg, "'", "’");
- ::replace_all(txt_fg, "\"", "\\\"");
- }
-
+ if (is_first) {
// background text
- fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
+ is_first = false;
+ }
- // foreground text
- fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+ // foreground text
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
- // underline
- fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
- }
+ // underline
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
}
+ }
- fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
+ fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
- fout << "\n\n";
- fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
- fout << "\n";
- fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
- fout << "\n";
+ fout << "\n\n";
+ fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
+ fout << "\n";
+ fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
+ fout << "\n";
- fout.close();
+ fout.close();
- fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
- }
+ fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true;
}
std::vector<float> pcmf32;
{
drwav wav;
- if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
- fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
- whisper_print_usage(argc, argv, {});
+
+ if (fname_inp == "-") {
+ std::vector<uint8_t> wav_data;
+ {
+ uint8_t buf[1024];
+ while (true)
+ {
+ const size_t n = fread(buf, 1, sizeof(buf), stdin);
+ if (n == 0)
+ {
+ break;
+ }
+ wav_data.insert(wav_data.end(), buf, buf + n);
+ }
+ }
+
+ if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false)
+ {
+ fprintf(stderr, "error: failed to open WAV file from stdin\n");
+ return 4;
+ }
+ }
+ else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
+ fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 4;
}
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
+ wparams.duration_ms = params.duration_ms;
+
+ wparams.token_timestamps = params.output_wts || params.max_len > 0;
+ wparams.thold_pt = params.word_thold;
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
// this callback is called on each new segment
if (!wparams.print_realtime) {
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_inp + ".wts";
- output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
+ output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
}
}
static const size_t MB = 1024*1024;
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
- { MODEL_TINY, 86ull*MB },
- { MODEL_BASE, 165ull*MB },
- { MODEL_SMALL, 540ull*MB },
- { MODEL_MEDIUM, 1650ull*MB },
- { MODEL_LARGE, 3260ull*MB },
+ { MODEL_TINY, 74ull*MB },
+ { MODEL_BASE, 142ull*MB },
+ { MODEL_SMALL, 466ull*MB },
+ { MODEL_MEDIUM, 1464ull*MB },
+ { MODEL_LARGE, 2952ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
+ { MODEL_TINY, 12ull*MB },
+ { MODEL_BASE, 24ull*MB },
+ { MODEL_SMALL, 70ull*MB },
+ { MODEL_MEDIUM, 184ull*MB },
+ { MODEL_LARGE, 306ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
+
+ // [EXPERIMENTAL] token-level timestamps data
+ int64_t t_beg;
+ int64_t t_last;
+ whisper_token tid_last;
+ std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
//
// see the convert-pt-to-ggml.py script for details
//
-bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
+static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
- wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
+ wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
}
}
- // create the ggml memory context
- {
- struct ggml_init_params params = {
- .mem_size = wctx.buf_memory.size(),
- .mem_buffer = wctx.buf_memory.data(),
- };
-
- model.ctx_mem = ggml_init(params);
- if (!model.ctx_mem) {
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
- return false;
- }
- }
-
// prepare memory for the weights
{
auto & ctx = model.ctx;
}
}
+ // create the ggml memory context
+ {
+ struct ggml_init_params params = {
+ .mem_size = wctx.buf_memory.size(),
+ .mem_buffer = wctx.buf_memory.data(),
+ };
+
+ model.ctx_mem = ggml_init(params);
+ if (!model.ctx_mem) {
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+ return false;
+ }
+ }
+
// key + value memory
{
auto & ctx = model.ctx_mem;
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
-bool whisper_encode(
+static bool whisper_encode(
whisper_context & wctx,
const int n_threads,
const int mel_offset) {
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
-bool whisper_decode(
+static bool whisper_decode(
whisper_context & wctx,
const int n_threads,
const whisper_token * tokens,
}
// the most basic sampling scheme - select the top token
-whisper_token_data whisper_sample_best(
+static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
- whisper_token_data result;
+ whisper_token_data result = {
+ 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+ };
int n_logits = vocab.id_to_token.size();
}
// samples only from the timestamps tokens
-whisper_vocab::id whisper_sample_timestamp(
+static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
-void dft(const std::vector<float> & in, std::vector<float> & out) {
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
-void fft(const std::vector<float> & in, std::vector<float> & out) {
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
-bool log_mel_spectrogram(
+static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+
/*.language =*/ "en",
/*.greedy =*/ {
return result;
}
+// forward declarations
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
+static void whisper_exp_compute_token_level_timestamps(
+ struct whisper_context * ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum);
+
+// wrap the last segment to max_len characters
+// returns the number of new segments
+static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
+ auto segment = ctx->result_all.back();
+
+ int res = 1;
+ int acc = 0;
+
+ std::string text;
+
+ for (int i = 0; i < (int) segment.tokens.size(); i++) {
+ const auto & token = segment.tokens[i];
+ if (token.id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ const auto txt = whisper_token_to_str(ctx, token.id);
+
+ const int cur = strlen(txt);
+
+ if (acc + cur > max_len && i > 0) {
+ // split here
+ ctx->result_all.back().text = std::move(text);
+ ctx->result_all.back().t1 = token.t0;
+ ctx->result_all.back().tokens.resize(i);
+
+ ctx->result_all.push_back({});
+ ctx->result_all.back().t0 = token.t0;
+ ctx->result_all.back().t1 = segment.t1;
+
+ // add tokens [i, end] to the new segment
+ ctx->result_all.back().tokens.insert(
+ ctx->result_all.back().tokens.end(),
+ segment.tokens.begin() + i,
+ segment.tokens.end());
+
+ acc = 0;
+ text = "";
+
+ segment = ctx->result_all.back();
+ i = -1;
+
+ res++;
+ } else {
+ acc += cur;
+ text += txt;
+ }
+ }
+
+ ctx->result_all.back().text = std::move(text);
+
+ return res;
+}
+
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
return -1;
}
+ if (params.token_timestamps) {
+ ctx->t_beg = 0;
+ ctx->t_last = 0;
+ ctx->tid_last = 0;
+ ctx->energy = get_signal_energy(samples, n_samples, 32);
+ }
+
const int seek_start = params.offset_ms/10;
+ const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
// if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
- if (whisper_n_len(ctx) < 100 + seek_start) {
+ if (seek_end < 100 + seek_start) {
return 0;
}
// main loop
int seek = seek_start;
while (true) {
- int progress_cur = (100*seek)/whisper_n_len(ctx);
+ const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
while (progress_cur >= progress_prev + progress_step) {
progress_prev += progress_step;
if (params.print_progress) {
}
}
- if (seek + 100 >= whisper_n_len(ctx)) {
+ if (seek + 100 >= seek_end) {
break;
}
// end of text token
if (token.id == whisper_token_eot(ctx)) {
if (result_len == 0) {
- if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
+ if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this
}
}
+ // shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(ctx, params.max_len);
+ }
+ }
if (params.new_segment_callback) {
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
text = "";
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(ctx, params.max_len);
+ }
+ }
if (params.new_segment_callback) {
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+ params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
}
}
return s.c_str();
}
+
+// =================================================================================================
+
+//
+// Experimental stuff below
+//
+// Not sure if these should be part of the library at all, because the quality of the results is not
+// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+ return (100*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+ float res = 0.0f;
+
+ for (size_t i = 0; i < text.size(); ++i) {
+ if (text[i] == ' ') {
+ res += 0.01f;
+ } else if (text[i] == ',') {
+ res += 2.00f;
+ } else if (text[i] == '.') {
+ res += 3.00f;
+ } else if (text[i] == '!') {
+ res += 3.00f;
+ } else if (text[i] == '?') {
+ res += 3.00f;
+ } else if (text[i] >= '0' && text[i] <= '9') {
+ res += 3.00f;
+ } else {
+ res += 1.00f;
+ }
+ }
+
+ return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+ const int hw = n_samples_per_half_window;
+
+ std::vector<float> result(n_samples);
+
+ for (int i = 0; i < n_samples; i++) {
+ float sum = 0;
+ for (int j = -hw; j <= hw; j++) {
+ if (i + j >= 0 && i + j < n_samples) {
+ sum += fabs(signal[i + j]);
+ }
+ }
+ result[i] = sum/(2*hw + 1);
+ }
+
+ return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+ struct whisper_context * ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum) {
+ auto & segment = ctx->result_all[i_segment];
+ auto & tokens = segment.tokens;
+
+ const int n_samples = ctx->energy.size();
+
+ if (n_samples == 0) {
+ fprintf(stderr, "%s: no signal data available\n", __func__);
+ return;
+ }
+
+ const int64_t t0 = segment.t0;
+ const int64_t t1 = segment.t1;
+
+ const int s0 = timestamp_to_sample(t0, n_samples);
+ const int s1 = timestamp_to_sample(t1, n_samples);
+
+ const int n = tokens.size();
+
+ if (n == 0) {
+ return;
+ }
+
+ if (n == 1) {
+ tokens[0].t0 = t0;
+ tokens[0].t1 = t1;
+
+ return;
+ }
+
+ auto & t_beg = ctx->t_beg;
+ auto & t_last = ctx->t_last;
+ auto & tid_last = ctx->tid_last;
+
+ for (int j = 0; j < n; ++j) {
+ auto & token = tokens[j];
+
+ if (j == 0) {
+ if (token.id == whisper_token_beg(ctx)) {
+ tokens[j ].t0 = t0;
+ tokens[j ].t1 = t0;
+ tokens[j + 1].t0 = t0;
+
+ t_beg = t0;
+ t_last = t0;
+ tid_last = whisper_token_beg(ctx);
+ } else {
+ tokens[j ].t0 = t_last;
+ }
+ }
+
+ const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+ tokens[j].id = token.id;
+ tokens[j].tid = token.tid;
+ tokens[j].p = token.p;
+ tokens[j].pt = token.pt;
+ tokens[j].ptsum = token.ptsum;
+
+ tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+
+ if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
+ if (j > 0) {
+ tokens[j - 1].t1 = tt;
+ }
+ tokens[j].t0 = tt;
+ tid_last = token.tid;
+ }
+ }
+
+ tokens[n - 2].t1 = t1;
+ tokens[n - 1].t0 = t1;
+ tokens[n - 1].t1 = t1;
+
+ t_last = t1;
+
+ // find intervals of tokens with unknown timestamps
+ // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+ {
+ int p0 = 0;
+ int p1 = 0;
+
+ while (true) {
+ while (p1 < n && tokens[p1].t1 < 0) {
+ p1++;
+ }
+
+ if (p1 >= n) {
+ p1--;
+ }
+
+ if (p1 > p0) {
+ double psum = 0.0;
+ for (int j = p0; j <= p1; j++) {
+ psum += tokens[j].vlen;
+ }
+
+ //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+ const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+ // split the time proportionally to the voice length
+ for (int j = p0 + 1; j <= p1; j++) {
+ const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+
+ tokens[j - 1].t1 = ct;
+ tokens[j ].t0 = ct;
+ }
+ }
+
+ p1++;
+ p0 = p1;
+ if (p1 >= n) {
+ break;
+ }
+ }
+ }
+
+ // fix up (just in case)
+ for (int j = 0; j < n - 1; j++) {
+ if (tokens[j].t1 < 0) {
+ tokens[j + 1].t0 = tokens[j].t1;
+ }
+
+ if (j > 0) {
+ if (tokens[j - 1].t1 > tokens[j].t0) {
+ tokens[j].t0 = tokens[j - 1].t1;
+ tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
+ }
+ }
+ }
+
+ // VAD
+ // expand or contract tokens based on voice activity
+ {
+ const int hw = WHISPER_SAMPLE_RATE/8;
+
+ for (int j = 0; j < n; j++) {
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
+ int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
+
+ const int ss0 = std::max(s0 - hw, 0);
+ const int ss1 = std::min(s1 + hw, n_samples);
+
+ const int ns = ss1 - ss0;
+
+ float sum = 0.0f;
+
+ for (int k = ss0; k < ss1; k++) {
+ sum += ctx->energy[k];
+ }
+
+ const float thold = 0.5*sum/ns;
+
+ {
+ int k = s0;
+ if (ctx->energy[k] > thold && j > 0) {
+ while (k > 0 && ctx->energy[k] > thold) {
+ k--;
+ }
+ tokens[j].t0 = sample_to_timestamp(k);
+ if (tokens[j].t0 < tokens[j - 1].t1) {
+ tokens[j].t0 = tokens[j - 1].t1;
+ } else {
+ s0 = k;
+ }
+ } else {
+ while (ctx->energy[k] < thold && k < s1) {
+ k++;
+ }
+ s0 = k;
+ tokens[j].t0 = sample_to_timestamp(k);
+ }
+ }
+
+ {
+ int k = s1;
+ if (ctx->energy[k] > thold) {
+ while (k < n_samples - 1 && ctx->energy[k] > thold) {
+ k++;
+ }
+ tokens[j].t1 = sample_to_timestamp(k);
+ if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+ tokens[j].t1 = tokens[j + 1].t0;
+ } else {
+ s1 = k;
+ }
+ } else {
+ while (ctx->energy[k] < thold && k > s0) {
+ k--;
+ }
+ s1 = k;
+ tokens[j].t1 = sample_to_timestamp(k);
+ }
+ }
+ }
+ }
+
+ // fixed token expand (optional)
+ //{
+ // const int t_expand = 0;
+
+ // for (int j = 0; j < n; j++) {
+ // if (j > 0) {
+ // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+ // }
+ // if (j < n - 1) {
+ // tokens[j].t1 = tokens[j].t1 + t_expand;
+ // }
+ // }
+ //}
+
+ // debug info
+ //for (int j = 0; j < n; ++j) {
+ // const auto & token = tokens[j];
+ // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+ // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+ // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+
+ // if (tokens[j].id >= whisper_token_eot(ctx)) {
+ // continue;
+ // }
+ //}
+}
typedef int whisper_token;
- struct whisper_token_data {
+ typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
- };
+
+ // token-level timestamp data
+ // do not use if you haven't computed token-level timestamps
+ int64_t t0; // start time of the token
+ int64_t t1; // end time of the token
+
+ float vlen; // voice length of the token
+ } whisper_token_data;
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
- WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
+ WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
// Return the id of the specified language, returns -1 if not found
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
- typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
+ typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
- int offset_ms;
+ int offset_ms; // start offset in ms
+ int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
bool print_realtime;
bool print_timestamps;
+ // [EXPERIMENTAL] token-level timestamps
+ bool token_timestamps; // enable token-level timestamps
+ float thold_pt; // timestamp token probability threshold (~0.01)
+ float thold_ptsum; // timestamp token sum probability threshold (~0.01)
+ int max_len; // max segment length in characters
+
const char * language;
struct {
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
- WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+ WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
#pragma once
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you wish some specific topics to be covered, feel free to drop a comment:
+//
+// https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+// - a set of tensor operations
+// - automatic differentiation
+// - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+// - linear regression
+// - support vector machines
+// - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+// {
+// struct ggml_init_params params = {
+// .mem_size = 16*1024*1024,
+// .mem_buffer = NULL,
+// };
+//
+// // memory allocation happens here
+// struct ggml_context * ctx = ggml_init(params);
+//
+// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+// ggml_set_param(ctx, x); // x is an input variable
+//
+// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+// ...
+// }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+// {
+// ...
+//
+// struct ggml_cgraph gf = ggml_build_forward(f);
+//
+// // set the input variable and parameter values
+// ggml_set_f32(x, 2.0f);
+// ggml_set_f32(a, 3.0f);
+// ggml_set_f32(b, 4.0f);
+//
+// ggml_graph_compute(ctx0, &gf);
+//
+// printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+// ...
+// }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach allows to define the function graph once and then compute its forward or backward graphs
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+// - ggml_permute()
+// - ggml_conv_1d_1s()
+// - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+// What is Automatic Differentiation?
+// https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+// {
+// struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+// assert(c->src[0] == a);
+// assert(c->src[1] == b);
+// }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
+// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+// {
+// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
+//
+// // a[1, 2] = 1.0f;
+// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
+//
+// // a[2, 0] = 2.0f;
+// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
+//
+// ...
+// }
+//
+// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
+//
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
#ifdef __cplusplus
extern "C" {
#endif
typedef uint16_t ggml_fp16_t;
#endif
-float ggml_fp16_to_fp32(ggml_fp16_t x);
+// convert FP16 <-> FP32
+float ggml_fp16_to_fp32(ggml_fp16_t x);
ggml_fp16_t ggml_fp32_to_fp16(float x);
struct ggml_object;
GGML_TYPE_COUNT,
};
+// available tensor operations:
enum ggml_op {
GGML_OP_NONE = 0,
void * mem_buffer; // if NULL, memory will be allocated internally
};
-void ggml_time_init(void);
+void ggml_time_init(void); // call this once at the beginning of the program
int64_t ggml_time_ms(void);
int64_t ggml_time_us(void);
int64_t ggml_cycles(void);
#include <stdint.h>
#include <stdio.h>
-#if defined _MSC_VER
+#if defined _MSC_VER || defined(__MINGW32__)
#include <Windows.h>
typedef volatile LONG atomic_int;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
- out = CreateThread(NULL, 0, func, arg, 0, NULL);
- return out != NULL;
+ HANDLE handle = CreateThread(NULL, 0, func, arg, 0, NULL);
+ if (handle == NULL)
+ {
+ return EAGAIN;
+ }
+
+ *out = handle;
+ return 0;
}
static int pthread_join(pthread_t thread, void* unused) {
// timing
//
-#if defined(_MSC_VER)
+#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq;
void ggml_time_init(void) {
LARGE_INTEGER frequency;
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
- for (int j = ith; j < n; j += nth) {
+ const int j0 = (n/nth)*ith;
+ const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
+
+ for (int j = j0; j < j1; j++) {
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01),
} break;
case GGML_OP_ADD:
{
- node->n_tasks = 1;
+ node->n_tasks = n_threads;
} break;
case GGML_OP_SUB:
case GGML_OP_MUL:
}
int ggml_cpu_has_neon(void) {
-#if defined(__ARM_NEON__)
+#if defined(__ARM_NEON)
return 1;
#else
return 0;