]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add whisper_state + default state on the whisper_context (#523)
authorsandrohanea <redacted>
Sun, 5 Mar 2023 19:42:19 +0000 (20:42 +0100)
committerGitHub <redacted>
Sun, 5 Mar 2023 19:42:19 +0000 (21:42 +0200)
* Added whisper state + default state on the whisper_context

* Fixed some examples and bindings

* Fixed whisper_n_len (which was used in some binding) and added whisper_n_len_from_state

* Fixed comments

* whisper : reuse kv_cache_free() and fix compiler warnings

* whisper : clean-up the API comments

---------

Co-authored-by: Sandro Hanea <redacted>
Co-authored-by: Georgi Gerganov <redacted>
bindings/go/whisper.go
bindings/ruby/ext/ruby_whisper.cpp
examples/addon.node/addon.cpp
examples/main/main.cpp
whisper.cpp
whisper.h

index 78ca07d66f36360898c749471c488fe58bd04b3e..d47f7f76b96b0524f8d1ab7aea0e45b7c124e462 100644 (file)
@@ -20,7 +20,7 @@ extern bool callEncoderBegin(void* user_data);
 // Text segment callback
 // Called on every newly generated text segment
 // Use the whisper_full_...() functions to obtain the text segments
-static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
+static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         callNewSegment(user_data, n_new);
     }
@@ -29,7 +29,7 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void*
 // Encoder begin callback
 // If not NULL, called before the encoder starts
 // If it returns false, the computation is aborted
-static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
+static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
     if(user_data != NULL && ctx != NULL) {
         return callEncoderBegin(user_data);
     }
index e7416ba26c4f84b79f2b198d854e7f3eef0666ff..82027d42fa516dba8a8e3468fb9fc4b213f12df8 100644 (file)
@@ -199,7 +199,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
   {
     static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
       bool is_aborted = *(bool*)user_data;
       return !is_aborted;
     };
index 2ef895f5e054371c70c95506c8e97faa00fb4a98..825232736bd6de8281ba4c06cb676ca0ee46c971 100644 (file)
@@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
     return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
 }
 
-void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
@@ -260,7 +260,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
                     bool is_aborted = *(bool*)user_data;
                     return !is_aborted;
                 };
index e1853c6ba854fd0cf7b6a9989573c83ad9680c9e..cd7d928375753bca5f7f546c7349eaf6e2de5893 100644 (file)
@@ -193,7 +193,7 @@ struct whisper_print_user_data {
     const std::vector<std::vector<float>> * pcmf32s;
 };
 
-void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
     const auto & params  = *((whisper_print_user_data *) user_data)->params;
     const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
 
@@ -608,7 +608,7 @@ int main(int argc, char ** argv) {
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
-                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
                     bool is_aborted = *(bool*)user_data;
                     return !is_aborted;
                 };
index b4a04073299933460926d205fce2c185a8276fd4..c8a904bb040e9c8c32018a3cd44aedb86711a1f6 100644 (file)
@@ -547,13 +547,11 @@ struct whisper_decoder {
     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
-struct whisper_context {
-    int64_t t_load_us   = 0;
-    int64_t t_mel_us    = 0;
+struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
-    int64_t t_start_us  = 0;
+    int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
@@ -561,16 +559,10 @@ struct whisper_context {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
-    ggml_type wtype; // weight type (FP32 or FP16)
-
-    whisper_mel mel;
-
-    whisper_model model;
-    whisper_vocab vocab;
-
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+    whisper_mel mel;
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
@@ -635,6 +627,18 @@ struct whisper_context {
     }
 };
 
+struct whisper_context {
+    int64_t t_load_us = 0;
+    int64_t t_start_us = 0;
+
+
+    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
+
+    whisper_model model;
+    whisper_vocab vocab;
+    whisper_state * state = nullptr;
+};
+
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
     loader->read(loader->context, &dest, sizeof(T));
@@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         wctx.model.buf = new std::vector<uint8_t>();
         wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
-            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
-            fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
-            return false;
-        }
-
-        {
-            const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
-            fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
-        }
-
-        wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-
-        wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
-        wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
-        wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
-        wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
+        // we skip initialization of the state until it is needed
+        // because it might be that state will always be provided externally.
     }
 
     // load mel filters
@@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 vocab.id_to_token[i] = word;
             }
         }
-
-        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
-
-        wctx.logits_id.reserve(n_vocab);
-
-        // TAGS: WHISPER_DECODER_INIT
-        wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-
-        wctx.decoders[0].probs.reserve   (vocab.n_vocab);
-        wctx.decoders[0].logits.reserve  (vocab.n_vocab);
-        wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
     }
 
     size_t ctx_size = 0;
@@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.rng = std::mt19937(0);
-
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
 }
 
-// evaluate the encoder
+// evaluate the encoder with the given state
 //
 // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
 // part of the transformer model and returns the encoded features
 //
-//   - model:      the model
+//   - wctx:      the model
+//   - wstate:     the state of the encoder
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
-static bool whisper_encode(
+static bool whisper_encode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads){
+
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
-    const auto & mel_inp = wctx.mel;
+    const auto & mel_inp = wstate.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
@@ -1374,12 +1344,12 @@ static bool whisper_encode(
     assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
@@ -1401,30 +1371,30 @@ static bool whisper_encode(
 
     // convolution + gelu
     {
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
         cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_1_b,
-                    cur),
-                cur);
+            ggml_repeat(ctx0,
+                model.e_conv_1_b,
+                cur),
+            cur);
 
         cur = ggml_gelu(ctx0, cur);
 
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
         cur = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    model.e_conv_2_b,
-                    cur),
-                cur);
+            ggml_repeat(ctx0,
+                model.e_conv_2_b,
+                cur),
+            cur);
 
         cur = ggml_gelu(ctx0, cur);
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1439,7 +1409,7 @@ static bool whisper_encode(
     //}
 
     static int iter = 0;
-
+    
     const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
@@ -1459,54 +1429,54 @@ static bool whisper_encode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
+                    cur),
+                ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
         }
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                    layer.attn_q_w,
-                    cur);
+                layer.attn_q_w,
+                cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
+                ggml_repeat(ctx0,
+                    layer.attn_q_b,
+                    Qcur),
+                Qcur);
 
             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                    layer.attn_k_w,
-                    cur);
+                layer.attn_k_w,
+                cur);
 
             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
+                layer.attn_v_w,
+                cur);
 
             Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
+                ggml_repeat(ctx0,
+                    layer.attn_v_b,
+                    Vcur),
+                Vcur);
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
@@ -1583,29 +1553,29 @@ static bool whisper_encode(
 #endif
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                KQV_merged,
+                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
-                    layer.attn_ln_1_w,
-                    cur);
+                layer.attn_ln_1_w,
+                cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
+                cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpL);
@@ -1616,61 +1586,61 @@ static bool whisper_encode(
         {
             // norm
             {
-                wctx.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
                 cur = ggml_norm(ctx0, inpFF);
 
-                wctx.use_buf(ctx0, 1);
+                wstate.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-            }
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, layer.mlp_ln_w, cur),
+                        cur),
+                    ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+    }
 
 #ifdef WHISPER_USE_FLASH_FF
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+                ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
+                layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // fully connected
             cur = ggml_mul_mat(ctx0,
-                    layer.mlp_0_w,
-                    cur);
+                layer.mlp_0_w,
+                cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // projection
             cur = ggml_mul_mat(ctx0,
-                    layer.mlp_1_w,
-                    cur);
+                layer.mlp_1_w,
+                cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
+                ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                cur);
 #endif
-        }
+}
 
-        wctx.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -1679,21 +1649,21 @@ static bool whisper_encode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         // cur = ln_f_g*cur + ln_f_b
         cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.e_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.e_ln_b, cur));
+            ggml_mul(ctx0,
+                ggml_repeat(ctx0, model.e_ln_w, cur),
+                cur),
+            ggml_repeat(ctx0, model.e_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -1701,7 +1671,7 @@ static bool whisper_encode(
         gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute       (ctx0, &gf);
+        ggml_graph_compute(ctx0, &gf);
 
         //ggml_graph_print(&gf);
     }
@@ -1731,34 +1701,34 @@ static bool whisper_encode(
         cur->src1 = nullptr;
 
         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-            auto & layer = model.layers_decoder[il];
+            auto& layer = model.layers_decoder[il];
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
-            struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_k_w,
-                    cur);
+            struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_k_w,
+                cur);
 
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
-            struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
-                    layer.cross_attn_v_w,
-                    cur);
+            struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_v_w,
+                cur);
 
             Vcross = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_v_b,
-                        Vcross),
-                    Vcross);
+                ggml_repeat(ctx0,
+                    layer.cross_attn_v_b,
+                    Vcross),
+                Vcross);
 
-            wctx.use_buf(ctx0, -1);
+            wstate.use_buf(ctx0, -1);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
+            //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1779,8 +1749,8 @@ static bool whisper_encode(
 
     ggml_free(ctx0);
 
-    wctx.t_encode_us += ggml_time_us() - t_start_us;
-    wctx.n_encode++;
+    wstate.t_encode_us += ggml_time_us() - t_start_us;
+    wstate.n_encode++;
 
     return true;
 }
@@ -1795,8 +1765,9 @@ static bool whisper_encode(
 //   - n_tokens:   number of tokens in the prompt
 //   - n_past:     number of past tokens to prefix the prompt with
 //
-static bool whisper_decode(
+static bool whisper_decode_internal(
         whisper_context & wctx,
+          whisper_state & wstate,
         whisper_decoder & decoder,
     const whisper_token * tokens,
               const int   n_tokens,
@@ -1811,7 +1782,7 @@ static bool whisper_decode(
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    auto & logits_out = wctx.logits;
+    auto & logits_out = wstate.logits;
 
     const int n_vocab = hparams.n_vocab;
 
@@ -1821,13 +1792,13 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
-    params.mem_size   = wctx.buf_compute.size();
-    params.mem_buffer = wctx.buf_compute.data();
+    params.mem_size   = wstate.buf_compute.size();
+    params.mem_buffer = wstate.buf_compute.data();
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1842,7 +1813,7 @@ static bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }
 
-    wctx.use_buf(ctx0, 3);
+    wstate.use_buf(ctx0, 3);
 
     // token encoding + position encoding
     struct ggml_tensor * cur =
@@ -1857,7 +1828,7 @@ static bool whisper_decode(
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpL);
 
@@ -1871,7 +1842,7 @@ static bool whisper_decode(
 
         // self-attention
         {
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
@@ -1913,7 +1884,7 @@ static bool whisper_decode(
 
             // ------
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -1929,12 +1900,12 @@ static bool whisper_decode(
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -1944,11 +1915,11 @@ static bool whisper_decode(
 
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * V_trans =
                 ggml_permute(ctx0,
@@ -1957,7 +1928,7 @@ static bool whisper_decode(
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
@@ -1970,31 +1941,31 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
         // norm
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2006,7 +1977,7 @@ static bool whisper_decode(
 
         // cross-attention
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
@@ -2023,19 +1994,19 @@ static bool whisper_decode(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
+                        ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
+                        ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
 
             // ------
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
@@ -2046,7 +2017,7 @@ static bool whisper_decode(
 
             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2060,15 +2031,15 @@ static bool whisper_decode(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -2080,20 +2051,20 @@ static bool whisper_decode(
 
         // projection
         {
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 2);
+        wstate.use_buf(ctx0, 2);
 
         // add the input
         cur = ggml_add(ctx0, cur, inpCA);
@@ -2104,11 +2075,11 @@ static bool whisper_decode(
         {
             // norm
             {
-                wctx.use_buf(ctx0, 0);
+                wstate.use_buf(ctx0, 0);
 
                 cur = ggml_norm(ctx0, inpFF);
 
-                wctx.use_buf(ctx0, 1);
+                wstate.use_buf(ctx0, 1);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -2118,39 +2089,39 @@ static bool whisper_decode(
                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_0_b, cur),
                     cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wctx.use_buf(ctx0, 1);
+            wstate.use_buf(ctx0, 1);
 
             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
 
-            wctx.use_buf(ctx0, 0);
+            wstate.use_buf(ctx0, 0);
 
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0, layer.mlp_1_b, cur),
                     cur);
         }
 
-        wctx.use_buf(ctx0, 3);
+        wstate.use_buf(ctx0, 3);
 
         inpL = ggml_add(ctx0, cur, inpFF);
     }
@@ -2159,11 +2130,11 @@ static bool whisper_decode(
 
     // norm
     {
-        wctx.use_buf(ctx0, 0);
+        wstate.use_buf(ctx0, 0);
 
         cur = ggml_norm(ctx0, cur);
 
-        wctx.use_buf(ctx0, 1);
+        wstate.use_buf(ctx0, 1);
 
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
@@ -2172,7 +2143,7 @@ static bool whisper_decode(
                 ggml_repeat(ctx0, model.d_ln_b, cur));
     }
 
-    wctx.use_buf(ctx0, 0);
+    wstate.use_buf(ctx0, 0);
 
     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
@@ -2181,7 +2152,7 @@ static bool whisper_decode(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    wctx.use_buf(ctx0, -1);
+    wstate.use_buf(ctx0, -1);
 
     // run the computation
     {
@@ -2208,8 +2179,8 @@ static bool whisper_decode(
 
     ggml_free(ctx0);
 
-    wctx.t_decode_us += ggml_time_us() - t_start_us;
-    wctx.n_decode++;
+    wstate.t_decode_us += ggml_time_us() - t_start_us;
+    wstate.n_decode++;
 
     return true;
 }
@@ -2313,7 +2284,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
 
 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
 static bool log_mel_spectrogram(
-        whisper_context & wctx,
+          whisper_state & wstate,
             const float * samples,
               const int   n_samples,
               const int   /*sample_rate*/,
@@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram(
         mel.data[i] = (mel.data[i] + 4.0)/4.0;
     }
 
-    wctx.t_mel_us += ggml_time_us() - t_start_us;
+    wstate.t_mel_us += ggml_time_us() - t_start_us;
 
     return true;
 }
@@ -2507,7 +2478,56 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
-struct whisper_context * whisper_init_from_file(const char * path_model) {
+struct whisper_state * whisper_init_state(whisper_context * ctx) {
+    whisper_state * state = new whisper_state;
+
+    const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
+
+
+    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+    }
+
+    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
+        fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        return nullptr;
+    }
+
+    {
+        const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
+        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+    }
+
+
+    state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
+
+    state->logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    // TAGS: WHISPER_DECODER_INIT
+    state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
+
+    state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
+
+    state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
+    state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
+    state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
+    state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
+
+    state->rng = std::mt19937(0);
+
+    return state;
+}
+
+struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
     whisper_model_loader loader = {};
 
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
         fin->close();
     };
 
-    return whisper_init(&loader);
+    return whisper_init_no_state(&loader);
 }
 
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
     struct buf_context {
         uint8_t* buffer;
         size_t size;
@@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
 
     loader.close = [](void * /*ctx*/) { };
 
-    return whisper_init(&loader);
+    return whisper_init_no_state(&loader);
 }
 
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     return ctx;
 }
 
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+    whisper_context * ctx = whisper_init_no_state(loader);
+    if (!ctx) {
+        return nullptr;
+    }
+
+    ctx->state = whisper_init_state(ctx);
+    if (!ctx->state) {
+        whisper_free(ctx);
+        return nullptr;
+    }
+
+    return ctx;
+}
+
+void whisper_free_state(struct whisper_state * state)
+{
+    if (state) {
+        kv_cache_free(state->kv_cross);
+
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            kv_cache_free(state->decoders[i].kv_self);
+        }
+
+        delete state;
+    }
+}
+
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
         if (ctx->model.ctx) {
@@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.buf) {
             delete ctx->model.buf;
         }
-        if (ctx->kv_cross.ctx) {
-            ggml_free(ctx->kv_cross.ctx);
-        }
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            if (ctx->decoders[i].kv_self.ctx) {
-                ggml_free(ctx->decoders[i].kv_self.ctx);
-            }
-        }
+
+        whisper_free_state(ctx->state);
+
         delete ctx;
     }
 }
 
+int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    return 0;
+}
+
 int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+    return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
         fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
-        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+    return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
+}
+
+int whisper_set_mel_with_state(
+        struct whisper_context * /*ctx*/,
+          struct whisper_state * state,
+                   const float * data,
+                           int   n_len,
+                           int   n_mel) {
+    if (n_mel != WHISPER_N_MEL) {
+        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
         return -1;
     }
 
+    state->mel.n_len = n_len;
+    state->mel.n_mel = n_mel;
+
+    state->mel.data.resize(n_len*n_mel);
+    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+
     return 0;
 }
 
@@ -2635,22 +2737,20 @@ int whisper_set_mel(
         const float * data,
         int n_len,
         int n_mel) {
-    if (n_mel != WHISPER_N_MEL) {
-        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+    return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
+}
+
+int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
+    if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
 
-    ctx->mel.n_len = n_len;
-    ctx->mel.n_mel = n_mel;
-
-    ctx->mel.data.resize(n_len*n_mel);
-    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
-
     return 0;
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode(*ctx, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
@@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     return 0;
 }
 
+int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to context
+    // TODO: add selected_decoder_id to state
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (ctx->state == nullptr) {
+        fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
+        return false;
+    }
+
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) {
     return nullptr;
 }
 
-int whisper_lang_auto_detect(
+int whisper_lang_auto_detect_with_state(
         struct whisper_context * ctx,
-        int offset_ms,
-        int n_threads,
-        float * lang_probs) {
+          struct whisper_state * state,
+                           int   offset_ms,
+                           int   n_threads,
+                         float * lang_probs) {
     const int seek = offset_ms/10;
 
     if (seek < 0) {
@@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect(
         return -1;
     }
 
-    if (seek >= ctx->mel.n_len) {
-        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+    if (seek >= state->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
         return -2;
     }
 
@@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect(
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
-    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+    if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
         fprintf(stderr, "%s: failed to decode\n", __func__);
         return -7;
     }
 
-    auto & logits_id = ctx->logits_id;
+    auto & logits_id = state->logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
+        logits_id.emplace_back(state->logits[token_lang], kv.second.first);
     }
 
     // sort descending
@@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect(
     return logits_id[0].second;
 }
 
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+                           int   offset_ms,
+                           int   n_threads,
+                         float * lang_probs) {
+    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
+}
+
+int whisper_n_len_from_state(struct whisper_state * state) {
+    return state->mel.n_len;
+}
+
 int whisper_n_len(struct whisper_context * ctx) {
-    return ctx->mel.n_len;
+    return ctx->state->mel.n_len;
 }
 
 int whisper_n_vocab(struct whisper_context * ctx) {
@@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
 }
 
 float * whisper_get_logits(struct whisper_context * ctx) {
-    return ctx->logits.data();
+    return ctx->state->logits.data();
+}
+
+
+float * whisper_get_logits_from_state(struct whisper_state * state) {
+    return state->logits.data();
 }
 
 const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
@@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    const int32_t n_sample = std::max(1, ctx->n_sample);
-    const int32_t n_encode = std::max(1, ctx->n_encode);
-    const int32_t n_decode = std::max(1, ctx->n_decode);
-
     fprintf(stderr, "\n");
-    fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
-    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
-    fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
-    fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
-    fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
-    fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
+    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    if (ctx->state != nullptr) {
+
+        const int32_t n_sample = std::max(1, ctx->state->n_sample);
+        const int32_t n_encode = std::max(1, ctx->state->n_encode);
+        const int32_t n_decode = std::max(1, ctx->state->n_decode);
+
+        fprintf(stderr, "%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        fprintf(stderr, "%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        fprintf(stderr, "%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        fprintf(stderr, "%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        fprintf(stderr, "%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+    }
     fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
-    ctx->t_sample_us = 0;
-    ctx->t_encode_us = 0;
-    ctx->t_decode_us = 0;
+    if (ctx->state != nullptr) {
+        ctx->state->t_sample_us = 0;
+        ctx->state->t_encode_us = 0;
+        ctx->state->t_decode_us = 0;
+    }
 }
 
 const char * whisper_print_system_info(void) {
@@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
                            int   i_segment,
                          float   thold_pt,
                          float   thold_ptsum);
@@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
-    auto segment = ctx.result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
+    auto segment = state.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
                 trim(text);
             }
 
-            ctx.result_all.back().text = std::move(text);
-            ctx.result_all.back().t1 = token.t0;
-            ctx.result_all.back().tokens.resize(i);
+            state.result_all.back().text = std::move(text);
+            state.result_all.back().t1 = token.t0;
+            state.result_all.back().tokens.resize(i);
 
-            ctx.result_all.push_back({});
-            ctx.result_all.back().t0 = token.t0;
-            ctx.result_all.back().t1 = segment.t1;
+            state.result_all.push_back({});
+            state.result_all.back().t0 = token.t0;
+            state.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx.result_all.back().tokens.insert(
-                    ctx.result_all.back().tokens.end(),
+            state.result_all.back().tokens.insert(
+                state.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx.result_all.back();
+            segment = state.result_all.back();
             i = -1;
 
             res++;
@@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
     if (split_on_word) {
         trim(text);
     }
-    ctx.result_all.back().text = std::move(text);
+    state.result_all.back().text = std::move(text);
 
     return res;
 }
@@ -3093,6 +3234,7 @@ static const std::vector<std::string> non_speech_tokens = {
 // - computes logprobs and probs
 static void whisper_process_logits(
               struct whisper_context & ctx,
+               struct whisper_state  & state,
     const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
                                float   temperature) {
@@ -3111,7 +3253,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -3149,7 +3291,7 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
 
         if (params.logits_filter_callback) {
-            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+            params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
 
         // suppress non-speech tokens
@@ -3310,6 +3452,7 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(ctx.rng);
+        result.id   = dist(state.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token(
         result.pt  = result.p;
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
+              whisper_state & state,
       const whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
@@ -3381,7 +3525,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = ctx.logits_id;
+    auto & logits_id = state.logits_id;
 
     logits_id.clear();
     for (int i = 0; i < n_logits; ++i) {
@@ -3434,7 +3578,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    ctx.n_sample++;
+    state.n_sample++;
 
     return result;
 }
@@ -3488,24 +3632,25 @@ static void whisper_sequence_score(
     }
 }
 
-int whisper_full(
+int whisper_full_with_state(
         struct whisper_context * ctx,
-        struct whisper_full_params params,
-        const float * samples,
-        int n_samples) {
+          struct whisper_state * state,
+    struct whisper_full_params   params,
+                   const float * samples,
+                           int   n_samples) {
     // clear old results
-    auto & result_all = ctx->result_all;
+    auto & result_all = state->result_all;
 
     result_all.clear();
 
     // compute log mel spectrogram
     if (params.speed_up) {
-        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
         }
     } else {
-        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
@@ -3515,26 +3660,26 @@ int whisper_full(
     if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
         std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
 
-        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
             fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
             return -3;
         }
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
         fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
     }
 
     if (params.token_timestamps) {
-        ctx->t_beg    = 0;
-        ctx->t_last   = 0;
-        ctx->tid_last = 0;
-        ctx->energy = get_signal_energy(samples, n_samples, 32);
+        state->t_beg    = 0;
+        state->t_last   = 0;
+        state->tid_last = 0;
+        state->energy = get_signal_energy(samples, n_samples, 32);
     }
 
     const int seek_start = params.offset_ms/10;
-    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
+    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
 
     // if length of spectrogram is less than 1s (100 samples), then return
     // basically don't process anything that is less than 1s
@@ -3572,10 +3717,10 @@ int whisper_full(
 
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
-        auto & decoder = ctx->decoders[j];
+        auto & decoder = state->decoders[j];
 
         if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = ctx->decoders[0].kv_self;
+            decoder.kv_self = state->decoders[0].kv_self;
             if (!kv_cache_reinit(decoder.kv_self)) {
                 fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
@@ -3583,7 +3728,7 @@ int whisper_full(
 
             WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
 
-            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
@@ -3592,7 +3737,7 @@ int whisper_full(
     }
 
     // the accumulated text context so far
-    auto & prompt_past = ctx->prompt_past;
+    auto & prompt_past = state->prompt_past;
     if (params.no_context) {
         prompt_past.clear();
     }
@@ -3611,13 +3756,13 @@ int whisper_full(
         fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
-    ctx->exp_n_audio_ctx = params.audio_ctx;
+    state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
-        ctx->lang_id = lang_id;
+        state->lang_id = lang_id;
         prompt_init.push_back(whisper_token_lang(ctx, lang_id));
         if (params.translate) {
             prompt_init.push_back(whisper_token_translate());
@@ -3669,14 +3814,14 @@ int whisper_full(
         }
 
         if (params.encoder_begin_callback) {
-            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+            if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
                 fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
-        if (!whisper_encode(*ctx, seek, params.n_threads)) {
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
             fprintf(stderr, "%s: failed to encode\n", __func__);
             return -6;
         }
@@ -3717,7 +3862,7 @@ int whisper_full(
 
             // TAGS: WHISPER_DECODER_INIT
             for (int j = 0; j < n_decoders_cur; ++j) {
-                auto & decoder = ctx->decoders[j];
+                auto & decoder = state->decoders[j];
 
                 decoder.kv_self.n = 0;
 
@@ -3759,7 +3904,7 @@ int whisper_full(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                     fprintf(stderr, "%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -3767,24 +3912,24 @@ int whisper_full(
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
+                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
 
-                    ctx->decoders[0].kv_self.n += prompt.size();
+                    state->decoders[0].kv_self.n += prompt.size();
 
                     for (int j = 1; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
-                        memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
                         decoder.kv_self.n += prompt.size();
 
-                        memcpy(decoder.probs.data(),    ctx->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
-                        memcpy(decoder.logits.data(),   ctx->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
-                        memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
+                        memcpy(decoder.probs.data(), state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+                        memcpy(decoder.logits.data(), state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                        memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                     }
 
-                    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
 
@@ -3795,7 +3940,7 @@ int whisper_full(
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
                     kv_bufs.resize(n_decoders_cur);
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3813,7 +3958,7 @@ int whisper_full(
 
                 // generate new sequence candidates for each decoder
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.completed || decoder.failed) {
                         continue;
@@ -3823,16 +3968,16 @@ int whisper_full(
                         case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                             {
                                 if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
                                 } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
                                 }
 
                                 decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
                             } break;
                         case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                             {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
 
                                 for (const auto & token : tokens_new) {
                                     beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
@@ -3857,7 +4002,7 @@ int whisper_full(
                     uint32_t cur_c = 0;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3886,7 +4031,7 @@ int whisper_full(
                 // - check if the sequence is failed
                 // - update sliding window based on timestamp tokens
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.completed || decoder.failed) {
                         continue;
@@ -3968,7 +4113,7 @@ int whisper_full(
                     bool completed_all = true;
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = ctx->decoders[j];
+                        auto & decoder = state->decoders[j];
 
                         if (decoder.completed || decoder.failed) {
                             continue;
@@ -3982,11 +4127,11 @@ int whisper_full(
                     }
                 }
 
-                ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.failed || decoder.completed) {
                         continue;
@@ -3997,7 +4142,7 @@ int whisper_full(
 
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
                         fprintf(stderr, "%s: failed to decode\n", __func__);
                         return -8;
                     }
@@ -4005,11 +4150,11 @@ int whisper_full(
                     {
                         const int64_t t_start_sample_us = ggml_time_us();
 
-                        whisper_process_logits(*ctx, params, decoder, t_cur);
+                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
 
                         ++decoder.kv_self.n;
 
-                        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
                     }
                 }
             }
@@ -4019,7 +4164,7 @@ int whisper_full(
                 double best_score = -INFINITY;
 
                 for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = ctx->decoders[j];
+                    auto & decoder = state->decoders[j];
 
                     if (decoder.failed) {
                         continue;
@@ -4036,7 +4181,7 @@ int whisper_full(
                                 __func__, j, decoder.sequence.entropy, params.entropy_thold);
 
                         decoder.failed = true;
-                        ctx->n_fail_h++;
+                        state->n_fail_h++;
 
                         continue;
                     }
@@ -4054,11 +4199,11 @@ int whisper_full(
             {
                 bool success = true;
 
-                const auto & decoder = ctx->decoders[best_decoder_id];
+                const auto & decoder = state->decoders[best_decoder_id];
 
                 if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
                     success = false;
-                    ctx->n_fail_p++;
+                    state->n_fail_p++;
                 }
 
                 if (success) {
@@ -4075,7 +4220,7 @@ int whisper_full(
 
         // output results through a user-provided callback
         {
-            const auto & best_decoder = ctx->decoders[best_decoder_id];
+            const auto & best_decoder = state->decoders[best_decoder_id];
 
             const auto seek_delta = best_decoder.seek_delta;
             const auto result_len = best_decoder.sequence.result_len;
@@ -4138,14 +4283,14 @@ int whisper_full(
 
                             if (params.token_timestamps) {
                                 whisper_exp_compute_token_level_timestamps(
-                                        *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                        *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                                 if (params.max_len > 0) {
-                                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+                                    n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                                 }
                             }
                             if (params.new_segment_callback) {
-                                params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                                params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                             }
                         }
                         text = "";
@@ -4182,14 +4327,14 @@ int whisper_full(
 
                     if (params.token_timestamps) {
                         whisper_exp_compute_token_level_timestamps(
-                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+                                *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
-                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+                            n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                         }
                     }
                     if (params.new_segment_callback) {
-                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                        params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                     }
                 }
             }
@@ -4204,6 +4349,15 @@ int whisper_full(
     return 0;
 }
 
+
+int whisper_full(
+        struct whisper_context * ctx,
+    struct whisper_full_params   params,
+                   const float * samples,
+                           int   n_samples) {
+    return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
+}
+
 int whisper_full_parallel(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -4213,40 +4367,10 @@ int whisper_full_parallel(
     if (n_processors == 1) {
         return whisper_full(ctx, params, samples, n_samples);
     }
-
     int ret = 0;
 
-    // prepare separate contexts for each thread
-    std::vector<struct whisper_context> ctxs(n_processors - 1);
-
-    for (int i = 0; i < n_processors - 1; ++i) {
-        auto & ctx_p = ctxs[i];
-
-        ctx_p = *ctx;
-
-        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
-
-        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
-
-        if (!kv_cache_reinit(ctx_p.kv_cross)) {
-            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
-            return false;
-        }
-
-        // TAGS: WHISPER_DECODER_INIT
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
-                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
-                return false;
-            }
-
-            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
-
-            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
-            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
-        }
-    }
+    // prepare separate states for each thread
+    std::vector<whisper_state*> states;
 
     const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
     const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
@@ -4256,6 +4380,9 @@ int whisper_full_parallel(
 
     std::vector<std::thread> workers(n_processors - 1);
     for (int i = 0; i < n_processors - 1; ++i) {
+        // create a new state for each thread
+        states.push_back(whisper_init_state(ctx));
+
         const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
         const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
 
@@ -4268,13 +4395,17 @@ int whisper_full_parallel(
         params_cur.new_segment_callback = nullptr;
         params_cur.new_segment_callback_user_data = nullptr;
 
-        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+        workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
     }
 
     {
         auto params_cur = params;
 
-        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+        // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
+        params_cur.print_realtime = false;
+
+        // Run the first transformation using default state but only for the first chunk.
+        ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
     }
 
     for (int i = 0; i < n_processors - 1; ++i) {
@@ -4283,45 +4414,43 @@ int whisper_full_parallel(
 
     const int64_t offset_t = (int64_t) params.offset_ms/10.0;
 
-    // combine results into ctx->result_all
+    // combine results into result_state->result_all from all other states
     for (int i = 0; i < n_processors - 1; ++i) {
-        auto & results_i = ctxs[i].result_all;
+        auto& results_i = states[i]->result_all;
 
-        for (auto & result : results_i) {
+        for (auto& result : results_i) {
             // correct the segment timestamp taking into account the offset
-            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
-            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+            result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+
 
             // make sure that segments are not overlapping
-            if (!ctx->result_all.empty()) {
-                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
+            if (!ctx->state->result_all.empty()) {
+                result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
             }
 
-            ctx->result_all.push_back(std::move(result));
+            ctx->state->result_all.push_back(std::move(result));
 
             // call the new_segment_callback for each segment
             if (params.new_segment_callback) {
-                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
+                params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
             }
         }
 
-        ctx->t_mel_us    += ctxs[i].t_mel_us;
-        ctx->t_sample_us += ctxs[i].t_sample_us;
-        ctx->t_encode_us += ctxs[i].t_encode_us;
-        ctx->t_decode_us += ctxs[i].t_decode_us;
+        ctx->state->t_mel_us += states[i]->t_mel_us;
 
-        kv_cache_free(ctx->kv_cross);
+        ctx->state->t_sample_us += states[i]->t_sample_us;
+        ctx->state->t_encode_us += states[i]->t_encode_us;
+        ctx->state->t_decode_us += states[i]->t_decode_us;
 
-        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
-            kv_cache_free(ctx->decoders[j].kv_self);
-        }
+        whisper_free_state(states[i]);
     }
 
     // average the timings
-    ctx->t_mel_us    /= n_processors;
-    ctx->t_sample_us /= n_processors;
-    ctx->t_encode_us /= n_processors;
-    ctx->t_decode_us /= n_processors;
+    ctx->state->t_mel_us    /= n_processors;
+    ctx->state->t_sample_us /= n_processors;
+    ctx->state->t_encode_us /= n_processors;
+    ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
     fprintf(stderr, "\n");
@@ -4334,44 +4463,84 @@ int whisper_full_parallel(
     return ret;
 }
 
+int whisper_full_n_segments_from_state(struct whisper_state * state) {
+    return state->result_all.size();
+}
+
 int whisper_full_n_segments(struct whisper_context * ctx) {
-    return ctx->result_all.size();
+    return ctx->state->result_all.size();
+}
+
+int whisper_full_lang_id_from_state(struct whisper_state * state) {
+    return state->lang_id;
 }
 
 int whisper_full_lang_id(struct whisper_context * ctx) {
-    return ctx->lang_id;
+    return ctx->state->lang_id;
+}
+
+int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].t0;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t0;
+    return ctx->state->result_all[i_segment].t0;
+}
+
+int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].t1;
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].t1;
+    return ctx->state->result_all[i_segment].t1;
+}
+
+const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].text.c_str();
 }
 
 const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].text.c_str();
+    return ctx->state->result_all[i_segment].text.c_str();
+}
+
+int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].tokens.size();
 }
 
 int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
-    return ctx->result_all[i_segment].tokens.size();
+    return ctx->state->result_all[i_segment].tokens.size();
 }
 
-const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token].id;
 }
 
 whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].id;
+    return ctx->state->result_all[i_segment].tokens[i_token].id;
+}
+
+struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token];
 }
 
 struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token];
+    return ctx->state->result_all[i_segment].tokens[i_token];
+}
+
+float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
+    return state->result_all[i_segment].tokens[i_token].p;
 }
 
 float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
-    return ctx->result_all[i_segment].tokens[i_token].p;
+    return ctx->state->result_all[i_segment].tokens[i_token].p;
 }
 
 // =================================================================================================
@@ -4583,13 +4752,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
 
 static void whisper_exp_compute_token_level_timestamps(
         struct whisper_context & ctx,
+          struct whisper_state & state,
                            int   i_segment,
                          float   thold_pt,
                          float   thold_ptsum) {
-    auto & segment = ctx.result_all[i_segment];
+    auto & segment = state.result_all[i_segment];
     auto & tokens  = segment.tokens;
 
-    const int n_samples = ctx.energy.size();
+    const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
         fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -4612,9 +4782,9 @@ static void whisper_exp_compute_token_level_timestamps(
         return;
     }
 
-    auto & t_beg    = ctx.t_beg;
-    auto & t_last   = ctx.t_last;
-    auto & tid_last = ctx.tid_last;
+    auto & t_beg    = state.t_beg;
+    auto & t_last   = state.t_last;
+    auto & tid_last = state.tid_last;
 
     for (int j = 0; j < n; ++j) {
         auto & token = tokens[j];
@@ -4737,15 +4907,15 @@ static void whisper_exp_compute_token_level_timestamps(
             float sum = 0.0f;
 
             for (int k = ss0; k < ss1; k++) {
-                sum += ctx.energy[k];
+                sum += state.energy[k];
             }
 
             const float thold = 0.5*sum/ns;
 
             {
                 int k = s0;
-                if (ctx.energy[k] > thold && j > 0) {
-                    while (k > 0 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold && j > 0) {
+                    while (k > 0 && state.energy[k] > thold) {
                         k--;
                     }
                     tokens[j].t0 = sample_to_timestamp(k);
@@ -4755,7 +4925,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s0 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k < s1) {
+                    while (state.energy[k] < thold && k < s1) {
                         k++;
                     }
                     s0 = k;
@@ -4765,8 +4935,8 @@ static void whisper_exp_compute_token_level_timestamps(
 
             {
                 int k = s1;
-                if (ctx.energy[k] > thold) {
-                    while (k < n_samples - 1 && ctx.energy[k] > thold) {
+                if (state.energy[k] > thold) {
+                    while (k < n_samples - 1 && state.energy[k] > thold) {
                         k++;
                     }
                     tokens[j].t1 = sample_to_timestamp(k);
@@ -4776,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps(
                         s1 = k;
                     }
                 } else {
-                    while (ctx.energy[k] < thold && k > s0) {
+                    while (state.energy[k] < thold && k > s0) {
                         k--;
                     }
                     s1 = k;
index 3eb8d08420948ca35533d19ffbd9047a55a78493..3984195dc8c1f93973cc8dd230a112d4ee7988a6 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -66,6 +66,7 @@ extern "C" {
     //
 
     struct whisper_context;
+    struct whisper_state;
 
     typedef int whisper_token;
 
@@ -101,11 +102,20 @@ extern "C" {
     WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
     WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
 
-    // Frees all memory allocated by the model.
-    WHISPER_API void whisper_free(struct whisper_context * ctx);
+    // These are the same as the above, but the internal state of the context is not allocated automatically
+    // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
+    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
+    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
+
+    WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
+
+    // Frees all allocated memory
+    WHISPER_API void whisper_free      (struct whisper_context * ctx);
+    WHISPER_API void whisper_free_state(struct whisper_state * state);
 
     // Convert RAW PCM audio to log mel spectrogram.
-    // The resulting spectrogram is stored inside the provided whisper context.
+    // The resulting spectrogram is stored inside the default state of the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel(
             struct whisper_context * ctx,
@@ -113,17 +123,30 @@ extern "C" {
                                int   n_samples,
                                int   n_threads);
 
-    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. 
-    // The resulting spectrogram is stored inside the provided whisper context.
+    WHISPER_API int whisper_pcm_to_mel_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                       const float * samples,
+                               int   n_samples,
+                               int   n_threads);
+
+    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
+    // The resulting spectrogram is stored inside the default state of the provided whisper context.
     // Returns 0 on success
     WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
-        struct whisper_context* ctx,
-        const float* samples,
-        int   n_samples,
-        int   n_threads);
-
-
-    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+        struct whisper_context * ctx,
+                   const float * samples,
+                           int   n_samples,
+                           int   n_threads);
+
+    WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+                   const float * samples,
+                           int   n_samples,
+                           int   n_threads);
+
+    // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
     // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
     // n_mel must be 80
     // Returns 0 on success
@@ -133,7 +156,14 @@ extern "C" {
                                int   n_len,
                                int   n_mel);
 
-    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+    WHISPER_API int whisper_set_mel_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                       const float * data,
+                               int   n_len,
+                               int   n_mel);
+
+    // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
     // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
     // offset can be used to specify the offset of the first frame in the spectrogram.
     // Returns 0 on success
@@ -142,6 +172,12 @@ extern "C" {
                                int   offset,
                                int   n_threads);
 
+    WHISPER_API int whisper_encode_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                               int   offset,
+                               int   n_threads);
+
     // Run the Whisper decoder to obtain the logits and probabilities for the next token.
     // Make sure to call whisper_encode() first.
     // tokens + n_tokens is the provided context for the decoder.
@@ -155,6 +191,14 @@ extern "C" {
                                int   n_past,
                                int   n_threads);
 
+    WHISPER_API int whisper_decode_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+               const whisper_token * tokens,
+                               int   n_tokens,
+                               int   n_past,
+                               int   n_threads);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
@@ -190,17 +234,26 @@ extern "C" {
                                int   n_threads,
                              float * lang_probs);
 
-    WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
-    WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
-    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
-    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
+    WHISPER_API int whisper_lang_auto_detect_with_state(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+                               int   offset_ms,
+                               int   n_threads,
+                             float * lang_probs);
+
+    WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
+    WHISPER_API int whisper_n_vocab         (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_text_ctx      (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
 
     // Token logits obtained from the last call to whisper_decode()
     // The logits for the last token are stored in the last row
     // Rows: n_tokens
     // Cols: n_vocab
-    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
+    WHISPER_API float * whisper_get_logits           (struct whisper_context * ctx);
+    WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
 
     // Token Id -> String. Uses the vocabulary in the provided context
     WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@@ -218,7 +271,7 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_translate (void);
     WHISPER_API whisper_token whisper_token_transcribe(void);
 
-    // Performance information
+    // Performance information from the default state.
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
 
@@ -236,18 +289,19 @@ extern "C" {
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
 
     // Encoder begin callback
     // If not NULL, called before the encoder starts
     // If it returns false, the computation is aborted
-    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
     typedef void (*whisper_logits_filter_callback)(
             struct whisper_context * ctx,
+              struct whisper_state * state,
           const whisper_token_data * tokens,
                                int   n_tokens,
                              float * logits,
@@ -334,6 +388,7 @@ extern "C" {
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
     // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Not thread safe for same context
     // Uses the specified decoding strategy to obtain the text.
     WHISPER_API int whisper_full(
                 struct whisper_context * ctx,
@@ -341,7 +396,16 @@ extern "C" {
                            const float * samples,
                                    int   n_samples);
 
-    // Split the input audio in chunks and process each chunk separately using whisper_full()
+    WHISPER_API int whisper_full_with_state(
+                struct whisper_context * ctx,
+                  struct whisper_state * state,
+            struct whisper_full_params   params,
+                           const float * samples,
+                                   int   n_samples);
+
+    // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
+    // Result is stored in the default state of the context
+    // Not thread safe if executed in parallel on the same context.
     // It seems this approach can offer some speedup in some cases.
     // However, the transcription accuracy can be worse at the beginning and end of each chunk.
     WHISPER_API int whisper_full_parallel(
@@ -351,33 +415,47 @@ extern "C" {
                                    int   n_samples,
                                    int   n_processors);
 
-    // Number of generated text segments.
+    // Number of generated text segments
     // A segment can be a few words, a sentence, or even a paragraph.
-    WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments           (struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
 
-    // Language id associated with the current context
+    // Language id associated with the context's default state
     WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
 
-    // Get the start and end time of the specified segment.
-    WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
-    WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
+    // Language id associated with the provided state
+    WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
+
+    // Get the start and end time of the specified segment
+    WHISPER_API int64_t whisper_full_get_segment_t0           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
+
+    WHISPER_API int64_t whisper_full_get_segment_t1           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
+
+    // Get the text of the specified segment
+    WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
 
-    // Get the text of the specified segment.
-    WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
+    // Get number of tokens in the specified segment
+    WHISPER_API int whisper_full_n_tokens           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
 
-    // Get number of tokens in the specified segment.
-    WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+    // Get the token text of the specified token in the specified segment
+    WHISPER_API const char * whisper_full_get_token_text           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
 
-    // Get the token text of the specified token in the specified segment.
-    WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
-    WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
 
-    // Get token data for the specified token in the specified segment.
+    // Get token data for the specified token in the specified segment
     // This contains probabilities, timestamps, etc.
-    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
 
-    // Get the probability of the specified token in the specified segment.
-    WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+    // Get the probability of the specified token in the specified segment
+    WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
 
     ////////////////////////////////////////////////////////////////////////////