// ggml helpers
//
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(
+ std::vector<uint8_t> & buf,
+ ggml_cgraph * graph,
+ int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+ plan.abort_callback = abort_callback;
+ plan.abort_callback_data = abort_callback_data;
+
if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.data();
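// Note: ggml_cplan already exposes abort_callback / abort_callback_data, so the helper
// only forwards them; the ggml compute loop is expected to poll the callback between
// graph nodes and stop early once it returns true. For the assignment above to compile,
// whisper_abort_callback must match ggml's bool (*)(void * user_data) shape, so a
// minimal caller-side callback could look like this (illustrative sketch, not part of
// this patch):
//
//     static bool my_abort(void * user_data) {
//         return *static_cast<bool *>(user_data); // true => abort graph computation
//     }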
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
- const int n_threads) {
+ const int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
// conv
ggml_allocr_alloc_graph(alloc, gf);
if (!whisper_encode_external(wstate)) {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
}
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
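// (The encoder runs more than one ggml graph in this version of whisper.cpp — e.g. the
// conv, encoder and cross-attention graphs are computed separately — so the same
// abort_callback forwarding is repeated at each compute call site below.)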
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
const whisper_token * tokens,
const int n_tokens,
const int n_past,
- const int n_threads) {
+ const int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
}
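// Note: only the CPU path (ggml_graph_compute_helper) receives the callback; the Metal
// branches still call ggml_metal_graph_compute unchanged, so in both the encoder and
// decoder above the abort check applies to CPU graph computation only.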
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
- if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return -1;
}
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
- if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return -1;
}
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int selected_decoder_id = 0;
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return 1;
}
return false;
}
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return 1;
}
}
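// The existing public entry points (whisper_encode, whisper_encode_with_state,
// whisper_decode, whisper_decode_with_state) pass nullptr for both new arguments, so
// their behavior is unchanged; only the whisper_full path below forwards the
// user-supplied callback from the params struct.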
// encode audio features starting at offset seek
- if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to encode\n", __func__);
return -6;
}
}
WHISPER_PRINT_DEBUG("\n\n");
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to decode\n", __func__);
return -7;
}
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to decode\n", __func__);
return -8;
}
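// Usage sketch (not part of this patch): with the abort_callback /
// abort_callback_user_data fields that whisper_full_params gains here, a caller could
// cancel a long transcription from another thread roughly like this. The atomic flag,
// the header include and the captureless lambda are illustrative assumptions, not code
// from this change.
//
//     #include <atomic>
//
//     static std::atomic<bool> g_abort{false};
//
//     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//     wparams.abort_callback = [](void * user_data) {
//         return static_cast<std::atomic<bool> *>(user_data)->load(); // true => stop computation
//     };
//     wparams.abort_callback_user_data = &g_abort;
//
//     // ... later, from another thread: g_abort = true;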
double tsum = 0.0;
// heat-up
- ggml_graph_compute_helper(work, &gf, n_threads);
+ ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
for (int i = 0; i < n_max; ++i) {
const int64_t t0 = ggml_time_us();
- ggml_graph_compute_helper(work, &gf, n_threads);
+ ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
const int64_t t1 = ggml_time_us();