]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add batched decoding (#1486)
authorGeorgi Gerganov <redacted>
Wed, 15 Nov 2023 14:12:52 +0000 (16:12 +0200)
committerGitHub <redacted>
Wed, 15 Nov 2023 14:12:52 +0000 (16:12 +0200)
* whisper : add whisper_batch

* whisper : move kv_self to whisper_state

* whisper : full batched decoding support

* whisper : fix memory leak in whisper_batch

* whisper : fix mem leak again + remove oboslete function

* whisper : clear kv cache when using whisper_decode API

* whisper : speed-up sampling

* whisper : fix decoders initializer

* bench : add batch size 5 bench

* whisper : add comment about the KV cache size

* whisper : add check for max number of decoders

* whisper : avoid starting sampling threads with bs=1

* whisper : enable beam-search by default

* cuda : sync llama.cpp fixes

examples/bench/bench.cpp
examples/main/main.cpp
extra/bench-all.sh
ggml-cuda.cu
ggml-cuda.h
whisper.cpp
whisper.h

index db1c4e800cd74983035af5e5e343128ef31ddfb5..949e5737167ef662dac68c5868a4cbb3d7b7028e 100644 (file)
@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
     }
     // heat encoder
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
@@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
 
     // prompt heat
     if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
     // text-generation heat
     if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to decode: %d\n", ret);
         return 4;
     }
 
@@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
 
     // actual run
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
-        fprintf(stderr, "error: failed to encode model: %d\n", ret);
+        fprintf(stderr, "error: failed to encode: %d\n", ret);
         return 4;
     }
 
-    for (int i = 0; i < 16; i++) {
-        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
-            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+    // text-generation
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
     }
 
-    for (int i = 0; i < 256; i++) {
-        if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
-            fprintf(stderr, "error: failed to encode model: %d\n", ret);
+    // batched decoding
+    for (int i = 0; i < 64; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
+            return 4;
+        }
+    }
+
+    // prompt processing
+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
+            fprintf(stderr, "error: failed to decode: %d\n", ret);
             return 4;
         }
     }
index e43dfe3f948f916836fb6022f4379355bd5d2bcf..98af5839ca551fe25f15b65d8c37b6c19d0672f1 100644 (file)
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step =  5;
     int32_t max_context  = -1;
     int32_t max_len      =  0;
-    int32_t best_of      =  2;
-    int32_t beam_size    = -1;
+    int32_t best_of      = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size    = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
 
     float word_thold    =  0.01f;
     float entropy_thold =  2.40f;
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
             if (params.detect_language) {
                 params.language = "auto";
             }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                    params.n_threads, params.n_processors,
+                    params.n_threads, params.n_processors, params.beam_size, params.best_of,
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
                     params.tinydiarize ? "tdrz = 1, " : "",
index db042673d698e3872a4467edfede25f41e9711d6..af8f67599a40efa2387ad11c5e49d82d1b65b3f9 100755 (executable)
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
-printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
     # actual run
@@ -56,6 +56,7 @@ for model in "${models[@]}"; do
     # parse the output:
     encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
     decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
+    batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
     prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
     system_info=$(echo "$output" | grep "system_info")
     n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@@ -94,6 +95,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
+        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
     fi
 done
index 058011a48f0da5189223145aab5c4aa26d1b3b3f..c0c9edd56dbc232b060afd62304d04bf4df45be3 100644 (file)
@@ -39,7 +39,6 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
-#define cudaDeviceGetMemPool hipDeviceGetMemPool
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -49,7 +48,6 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
-#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
@@ -57,7 +55,6 @@
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
-#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -66,9 +63,6 @@
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 #define cudaMemcpyKind hipMemcpyKind
-#define cudaMemPool_t hipMemPool_t
-#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
-#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
 #define cudaMemset hipMemset
 #define cudaMemsetAsync hipMemsetAsync
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
@@ -94,6 +88,8 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 
+#define GGML_CUDA_MAX_NODES 8192
+
 // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
 // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
 // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cudaError_t err_ = (err);                                                       \
         if (err_ != cudaSuccess) {                                                      \
-            int dev_id;                                                                     \
-            cudaGetDevice(&dev_id);                                                         \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", dev_id);                                \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int dev_id;                                                                     \
-            cudaGetDevice(&dev_id);                                                         \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
-            fprintf(stderr, "current device: %d\n", dev_id);                                \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -440,6 +436,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -472,7 +470,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
-static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -990,7 +1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1094,7 +1109,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1198,7 +1213,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
@@ -1452,7 +1467,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -4262,7 +4277,7 @@ template <bool need_check> static __global__ void
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4302,7 +4317,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4741,7 +4756,7 @@ static  __global__ void im2col_f32_f16(
         int ofs0, int ofs1, int IW, int IH, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-       const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
 
     const int offset_dst =
         (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
@@ -4793,6 +4808,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -4901,7 +4926,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4910,7 +4936,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4919,7 +4945,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4928,7 +4954,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4937,7 +4963,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4947,7 +4973,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4956,7 +4982,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4965,7 +4991,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4980,7 +5006,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4988,7 +5014,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4997,7 +5023,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5006,7 +5032,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5015,7 +5041,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5024,7 +5050,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5033,7 +5059,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5042,7 +5068,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5051,7 +5077,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5060,7 +5086,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5069,7 +5095,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5088,7 +5114,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<1, 1, convert_f16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -5825,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_malloc(size, actual_size);
-    }
-    void *ptr;
-    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-    *actual_size = size;
-    return ptr;
-}
-
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5852,12 +5868,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
 
-static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_free(ptr, actual_size);
-    }
-    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
 }
 
 void ggml_init_cublas() {
@@ -5872,7 +5886,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5914,19 +5933,13 @@ void ggml_init_cublas() {
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-            // configure memory pool
-            cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-            if (err == cudaSuccess) {
-                size_t treshold = UINT64_MAX;
-                CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-            }
         }
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 
@@ -6193,6 +6206,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_sqr(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6514,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6525,12 +6566,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-        size_t dst_f16_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6548,15 +6589,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        if (dst_f16_as != 0) {
-            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-        }
+        ggml_cuda_pool_free(dst_f16, dst_as);
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
+
         if (src1_as != 0) {
-            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
     }
     else {
@@ -6566,7 +6606,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6583,7 +6623,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta,  dst_dd_i,   ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
         }
     }
 
@@ -7008,6 +7048,8 @@ static void ggml_cuda_op_mul_mat(
     int64_t  row_low[GGML_CUDA_MAX_DEVICES];
     int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
+    int used_devices = 0;
+
     for (int64_t id = 0; id < g_device_count; ++id) {
         // by default, use all rows
         row_low[id]  = 0;
@@ -7035,6 +7077,8 @@ static void ggml_cuda_op_mul_mat(
             continue;
         }
 
+        used_devices++;
+
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
@@ -7045,22 +7089,21 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
         if (convert_src1_to_q8_1) {
-            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                // CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaGetLastError());
             }
         }
 
@@ -7068,18 +7111,18 @@ static void ggml_cuda_op_mul_mat(
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id,  stream);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
     }
 
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
+    if (split && used_devices > 1) {
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
 
-    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7194,6 +7237,27 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+            continue;
+        }
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        }
+    }
+
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7201,6 +7265,9 @@ static void ggml_cuda_op_mul_mat(
 
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
+            if (row_low[id] == row_high[id]) {
+                continue;
+            }
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
@@ -7211,21 +7278,6 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
-
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-        }
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7252,6 +7304,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
+}
+
+static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
+}
+
 static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
@@ -7261,6 +7321,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7412,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7470,8 +7532,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         size_t ptrs_src_s = 0;
         size_t ptrs_dst_s = 0;
 
-        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7484,6 +7546,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
+
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -7495,30 +7558,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
         }
         if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
         }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-    if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-    }
-    if (dst_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-    }
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU) &&
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);
 
+    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
@@ -7540,13 +7602,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
@@ -7667,7 +7729,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
-void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
 
@@ -7782,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0;
 
 static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     if (g_temp_tensor_extras == nullptr) {
-        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
     }
 
     size_t alloc_index = g_temp_tensor_extra_index;
-    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
     ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
     memset(extra, 0, sizeof(*extra));
 
@@ -7953,6 +8015,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
@@ -7995,6 +8059,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_RELU:
+                    func = ggml_cuda_relu;
+                    break;
                 default:
                     return false;
             } break;
@@ -8013,6 +8080,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_SQR:
+            func = ggml_cuda_sqr;
+            break;
         case GGML_OP_CLAMP:
             if (!any_on_device) {
                 return false;
@@ -8105,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda {
 
     ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
         if (temp_tensor_extras == nullptr) {
-            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
+            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
         }
 
         size_t alloc_index = temp_tensor_extra_index;
-        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
+        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
         ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
         memset(extra, 0, sizeof(*extra));
 
index 57adc9cf34bc5bc4fae4d576b4c2b1572364b8ff..528e66c33a20738ce185744ff5780203c869e4ad 100644 (file)
@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
index c0e91152703b1ead197e2f710dbe5eb1e0858ec7..a3e0fbd007769130e235cb33f661aaf5fe51558a 100644 (file)
@@ -20,6 +20,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
 //
@@ -406,6 +407,121 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id;   // null terminated
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * (n_tokens));
+    batch.pos      = (whisper_pos *)     malloc(sizeof(whisper_pos)      * (n_tokens));
+    batch.n_seq_id = (int32_t *)         malloc(sizeof(int32_t)          * (n_tokens));
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id)   * n_seq_max);
+    }
+    batch.seq_id[n_tokens] = nullptr;
+    batch.logits   = (int8_t *)          malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i]; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = seq_id;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc  = allocr.alloc;
+    auto & meta   = allocr.meta;
+
+    alloc = ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
+
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
+
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc = ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
@@ -523,15 +639,31 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
     struct ggml_context * ctx;
 
     ggml_backend_buffer_t buffer;
-
-    int n; // number of tokens currently in the cache
 };
 
 struct whisper_model {
@@ -585,11 +717,11 @@ struct whisper_partial_utf8 {
 };
 
 struct whisper_grammar {
-    /*const*/ std::vector<std::vector<whisper_grammar_element>>   rules;
-    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
-    whisper_partial_utf8                                      partial_utf8;
+    whisper_partial_utf8 partial_utf8;
 };
 
 struct whisper_grammar_candidate {
@@ -613,15 +745,13 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
     // grammar parse state of generated sequence of tokens
     whisper_grammar  grammar;
 
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
@@ -633,100 +763,40 @@ struct whisper_decoder {
     std::vector<float> logits;
     std::vector<float> logprobs;
 
-    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
-};
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_allocr * alloc = nullptr;
-
-    std::vector<uint8_t> meta;
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
-    ggml_backend_buffer_t buffer;
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc  = allocr.alloc;
-    auto & meta   = allocr.meta;
-
-    alloc = ggml_allocr_new_measure_from_backend(backend);
-
-    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-
-    ggml_allocr_alloc_graph(alloc, get_graph());
-}
-
-static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
-    if (allocr.alloc == nullptr) {
-        // this can be null if we use external encoder like CoreML or OpenVINO
-        return;
-    }
-
-    auto & alloc  = allocr.alloc;
-    auto & buffer = allocr.buffer;
-
-    size_t size = ggml_allocr_max_size(alloc);
-
-    ggml_allocr_free(alloc);
-
-    buffer = ggml_backend_alloc_buffer(backend, size);
-    alloc = ggml_allocr_new_from_buffer(buffer);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        ggml_allocr_free(allocr.alloc);
-        ggml_backend_buffer_free(allocr.buffer);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
-    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1 (prompt encoding)
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1  (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens <  16 (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1  (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
-    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+    whisper_batch batch;
 
-    // buffer for swapping KV caches between decoders during beam-search
-    std::vector<kv_buf> kv_swap_bufs;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
     ggml_backend_t backend = nullptr;
 
@@ -742,8 +812,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
-    // helper for GPU offloading
+    // helpers for GPU offloading
     std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -751,11 +822,6 @@ struct whisper_state {
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
@@ -831,6 +897,12 @@ static bool kv_cache_init(
         /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
@@ -858,54 +930,129 @@ static bool kv_cache_init(
     return true;
 }
 
-// TODO: remove after batched decoding
-static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
-    WHISPER_ASSERT(cache.ctx);
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
 
-    const int n_elements = ggml_nelements(cache.k);
-    WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+static bool whisper_kv_cache_find_slot(
+           struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
 
-    const ggml_type wtype = cache.k->type;
-    WHISPER_ASSERT(wtype == cache.v->type);
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
+    uint32_t n_tested = 0;
 
-    cache.ctx = ggml_init(params);
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
 
-    if (!cache.ctx) {
-        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
-        return false;
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
 
-    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
 
-    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+    return true;
+}
 
-    // allocate the tensors into the backend buffer
-    {
-        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
 
-        ggml_allocr_alloc(alloc, cache.k);
-        ggml_allocr_alloc(alloc, cache.v);
+    return 1;
+}
 
-        ggml_allocr_free(alloc);
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
     }
+    cache.head = 0;
+}
 
-    return true;
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    if (cache.ctx) {
-        ggml_free(cache.ctx);
-        ggml_backend_buffer_free(cache.buffer);
-        cache.ctx = nullptr;
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id_src,
+                 whisper_seq_id   seq_id_dst,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
     }
 }
 
@@ -914,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 
     // initialize the backends
 #ifdef GGML_USE_CUBLAS
-    if (params.use_gpu) {
+    if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
         backend_gpu = ggml_backend_cuda_init();
         if (!backend_gpu) {
@@ -1116,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_EOT_]";
                 } else if (i == vocab.token_sot) {
                     word = "[_SOT_]";
+                } else if (i == vocab.token_translate) {
+                    word = "[_TRANSLATE_]";
+                } else if (i == vocab.token_transcribe) {
+                    word = "[_TRANSCRIBE_]";
                 } else if (i == vocab.token_solm) {
                     word = "[_SOLM_]";
                 } else if (i == vocab.token_prev) {
@@ -1126,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_NOT_]";
                 } else if (i == vocab.token_beg) {
                     word = "[_BEG_]";
+                } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                    word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
                 } else {
                     word = "[_extra_token_" + std::to_string(i) + "]";
                 }
@@ -2031,26 +2184,28 @@ static bool whisper_encode_internal(
 static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
-         whisper_decoder & decoder,
-     const whisper_token * tokens,
-                   int   n_tokens,
-                   int   n_past) {
+     const whisper_batch & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    auto & kv_self = decoder.kv_self;
+    auto & kv_self = wstate.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    const int n_ctx   = hparams.n_text_ctx;
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
-    const int N = n_tokens;
-    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_tokens    = batch.n_tokens;
+    const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    const int32_t n_kv     = ggml_allocr_is_measure(alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
@@ -2062,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-    ggml_allocr * alloc = wstate.alloc_decode.alloc;
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
     }
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        for (int i = 0; i < N; ++i) {
-            const int32_t val = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            const int32_t val = batch.pos[i];
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
@@ -2089,6 +2242,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_allocr_alloc(alloc, KQ_mask);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        wstate.inp_mask.resize(n_kv*n_tokens);
+
+        float * data = wstate.inp_mask.data();
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const whisper_pos    pos    = batch.pos[j];
+                const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+    }
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2141,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             Vcur,
                             layer.attn_v_b);
 
-                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2156,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_past + N, n_head,
+                        n_state/n_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
                         ggml_element_size(kv_self.k)*n_state/n_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_state/n_head, n_head,
+                        n_kv, n_state/n_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2188,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2232,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, M, n_head,
+                        n_state/n_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, M);
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        M, n_state/n_head, n_head,
-                        M*ggml_element_size(wstate.kv_cross.v),
-                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2279,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, N)
+            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2354,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     }
 
     // compute logits only for the last token
-    // comment this line to compute logits for all tokens
+    // comment this line to compute logits for all n_tokens
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
@@ -2380,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 static bool whisper_decode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
-        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
+    const whisper_batch & batch,
               const int   n_threads,
  whisper_abort_callback   abort_callback,
                    void * abort_callback_data) {
@@ -2392,19 +2568,33 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
+
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
+
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    }
+
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
@@ -2413,17 +2603,15 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // extract logits for all N tokens
-    //logits_out.resize(n_tokens*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
-    if (n_tokens > 1) {
+    if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2432,18 +2620,20 @@ static bool whisper_decode_internal(
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
-    if (n_tokens == 1) {
+    if (batch.n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
+    } else if (batch.n_tokens < 16) {
+        wstate.t_batchd_us += ggml_time_us() - t_start_us;
+        wstate.n_batchd += n_tokens;
     } else {
         wstate.t_prompt_us += ggml_time_us() - t_start_us;
-        wstate.n_prompt++;
+        wstate.n_prompt += n_tokens;
     }
 
     return !(abort_callback && abort_callback(abort_callback_data));
 }
 
-
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2855,14 +3045,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
+    // in theory, there can be a case where this is not enough, but in practice it should always be enough
+    const int factor = 3;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
-        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
@@ -2897,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    state->logits_id.reserve(ctx->model.hparams.n_vocab);
+    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
-    state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
-    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    state->decoders[0].rng = std::mt19937(0);
 
     // conv allocator
     {
@@ -2946,7 +3143,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     const int n_tokens = hparams.n_text_ctx;
                     const int n_past   = 0;
 
-                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                    whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
+
+                    return whisper_build_graph_decoder(*ctx, *state, state->batch);
                 });
 
         WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -2957,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
-    state->rng = std::mt19937(0);
-
     return state;
 }
 
@@ -3183,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
+        kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
 
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            kv_cache_free(state->decoders[i].kv_self);
-        }
-
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);
@@ -3203,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+        whisper_batch_free(state->batch);
+
         whisper_allocr_free(state->alloc_conv);
         whisper_allocr_free(state->alloc_encode);
         whisper_allocr_free(state->alloc_cross);
@@ -3329,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int selected_decoder_id = 0;
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
+
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3340,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to state
-    const int selected_decoder_id = 0;
-
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3436,7 +3635,7 @@ int whisper_lang_auto_detect_with_state(
         return -7;
     }
 
-    auto & logits_id = state->logits_id;
+    auto & logits_id = state->decoders[0].logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
@@ -3639,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
@@ -3646,6 +3846,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
         WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@@ -3662,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
+        ctx->state->n_batchd = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ -3969,8 +4171,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
             // that cannot satisfy this position in grammar
-            if (tok.partial_utf8.n_remain != 0 &&
-                    !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
                 rejects.push_back(tok);
             }
         } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@@ -4189,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts    =*/  1.0f,
         /*.length_penalty    =*/ -1.0f,
 
-        /*.temperature_inc   =*/  0.4f,
+        /*.temperature_inc   =*/  0.2f,
         /*.entropy_thold     =*/  2.4f,
         /*.logprob_thold     =*/ -1.0f,
         /*.no_speech_thold   =*/  0.6f,
@@ -4229,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
-                    /*.best_of   =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.best_of   =*/ 5,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
-                    /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.beam_size =*/ 5,
 
                     /*.patience  =*/ -1.0f,
                 };
@@ -4325,11 +4526,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
+// TODO: optimize
 static void whisper_process_logits(
               struct whisper_context & ctx,
                struct whisper_state  & state,
-    const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
+    const struct whisper_full_params   params,
                                float   temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -4346,7 +4548,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -4512,30 +4714,31 @@ static void whisper_process_logits(
             //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
-                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
-            } else if (params.n_grammar_rules > 0) {
-                whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+            } else {
+                if (params.n_grammar_rules > 0) {
+                    whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
 
-                // populate the logprobs array (log_softmax)
-                {
-                    const float logit_max = *std::max_element(logits.begin(), logits.end());
-                    float logsumexp = 0.0f;
-                    for (int i = 0; i < n_logits; ++i) {
-                        if (logits[i] > -INFINITY) {
-                            logsumexp += expf(logits[i] - logit_max);
+                    // populate the logprobs array (log_softmax)
+                    {
+                        const float logit_max = *std::max_element(logits.begin(), logits.end());
+                        float logsumexp = 0.0f;
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logsumexp += expf(logits[i] - logit_max);
+                            }
                         }
-                    }
-                    logsumexp = logf(logsumexp) + logit_max;
+                        logsumexp = logf(logsumexp) + logit_max;
 
-                    for (int i = 0; i < n_logits; ++i) {
-                        if (logits[i] > -INFINITY) {
-                            logprobs[i] = logits[i] - logsumexp;
-                        } else {
-                            logprobs[i] = -INFINITY;
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logprobs[i] = logits[i] - logsumexp;
+                            } else {
+                                logprobs[i] = -INFINITY;
+                            }
                         }
                     }
                 }
@@ -4610,7 +4813,6 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
-              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -4655,7 +4857,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(state.rng);
+        result.id   = dist(decoder.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -4665,15 +4867,12 @@ static whisper_token_data whisper_sample_token(
         result.pt  = result.p;
     }
 
-    state.n_sample++;
-
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
-              whisper_state & state,
-      const whisper_decoder & decoder,
+            whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
 
@@ -4683,7 +4882,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = state.logits_id;
+    auto & logits_id = decoder.logits_id;
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
@@ -4732,7 +4931,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
     std::discrete_distribution<> dist(probs.begin(), probs.end());
 
     for (int i = 0; i < k; ++i) {
-        const auto id = dist(state.rng);
+        const auto id = dist(decoder.rng);
         //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
@@ -4743,8 +4942,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    state.n_sample++;
-
     return result;
 }
 
@@ -4797,125 +4994,6 @@ static void whisper_sequence_score(
     }
 }
 
-static bool whisper_kv_swap_fast(
-                   std::vector<int> & view,
-                    whisper_decoder   src[],
-                std::vector<kv_buf> & kv_swap_bufs,
-                          const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
-        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -5005,25 +5083,23 @@ int whisper_full_with_state(
 
     n_decoders = std::max(1, n_decoders);
 
+    if (n_decoders > WHISPER_MAX_DECODERS) {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
-                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
-
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
-        }
+        decoder.rng = std::mt19937(0);
     }
 
     // the accumulated text context so far
@@ -5100,8 +5176,10 @@ int whisper_full_with_state(
         bool has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar grammar;
     };
 
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
     std::vector<beam_candidate> beam_candidates;
 
     // main loop
@@ -5169,8 +5247,6 @@ int whisper_full_with_state(
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = state->decoders[j];
 
-                decoder.kv_self.n = 0;
-
                 decoder.sequence.tokens.clear();
                 decoder.sequence.result_len       = 0;
                 decoder.sequence.sum_logprobs_all = 0.0;
@@ -5186,15 +5262,14 @@ int whisper_full_with_state(
                 decoder.has_ts    = false;
 
                 if (params.grammar_rules != nullptr) {
-                    decoder.grammar = whisper_grammar_init(
-                        params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                    decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
                 } else {
                     decoder.grammar = {};
                 }
             }
 
             // init prompt and kv cache for the current iteration
-            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+            // TODO: do not recompute the prompt if it is the same as previous time
             {
                 prompt.clear();
 
@@ -5216,7 +5291,11 @@ int whisper_full_with_state(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                whisper_kv_cache_clear(state->kv_self);
+
+                whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -5224,20 +5303,14 @@ int whisper_full_with_state(
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+                    state->decoders[0].i_batch = prompt.size() - 1;
 
-                    state->decoders[0].kv_self.n += prompt.size();
+                    whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
 
-                        // TODO: fix CUDA
-                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
-                        decoder.kv_self.n += prompt.size();
+                        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                         memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
                         memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -5252,41 +5325,81 @@ int whisper_full_with_state(
                 const int64_t t_start_sample_us = ggml_time_us();
 
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    beam_candidates.clear();
+                    for (auto & bc : bc_per_dec) {
+                        bc.clear();
+                    }
                 }
 
-                // generate new sequence candidates for each decoder
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                // sampling
+                // TODO: avoid memory allocations, optimize, avoid threads?
+                {
+                    std::atomic<int> j_cur(0);
 
-                    if (decoder.completed || decoder.failed) {
-                        continue;
-                    }
+                    auto process = [&]() {
+                        while (true) {
+                            const int j = j_cur.fetch_add(1);
 
-                    switch (params.strategy) {
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                            {
-                                if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                                } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                                }
+                            if (j >= n_decoders_cur) {
+                                break;
+                            }
 
-                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                            } break;
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                            {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+                            auto & decoder = state->decoders[j];
 
-                                for (const auto & token : tokens_new) {
-                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
-                                    beam_candidates.back().sequence.tokens.push_back(token);
-                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+                            if (decoder.completed || decoder.failed) {
+                                continue;
+                            }
 
-                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
-                                }
-                            } break;
+                            switch (params.strategy) {
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                    {
+                                        if (t_cur < 1e-6f) {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                        } else {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                        }
+
+                                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                    } break;
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                    {
+                                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                        for (const auto & token : tokens_new) {
+                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                        }
+                                    } break;
+                            };
+                        }
                     };
+
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    if (n_threads == 1) {
+                        process();
+                    } else {
+                        std::vector<std::thread> threads(n_threads - 1);
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t].join();
+                        }
+                    }
+                }
+
+                beam_candidates.clear();
+                for (const auto & bc : bc_per_dec) {
+                    beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                    if (!bc.empty()) {
+                        state->n_sample += 1;
+                    }
                 }
 
                 // for beam-search, choose the top candidates and update the KV caches
@@ -5299,7 +5412,6 @@ int whisper_full_with_state(
                     });
 
                     uint32_t cur_c = 0;
-                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -5318,17 +5430,28 @@ int whisper_full_with_state(
                             ++cur_c;
                         }
 
-                        decoder.sequence   = cur.sequence;
                         decoder.seek_delta = cur.seek_delta;
                         decoder.has_ts     = cur.has_ts;
+                        decoder.sequence   = cur.sequence;
+                        decoder.grammar    = cur.grammar;
+
+                        whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
-                        decoder_idx[j] = cur.decoder_idx;
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                 __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                     }
 
-                    // update KV caches
-                    whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
+
+                        whisper_kv_cache_seq_rm(state->kv_self, j,                           -1, -1);
+                        whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+                        whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j,    -1, -1);
+                    }
                 }
 
                 // update the decoder state
@@ -5437,32 +5560,83 @@ int whisper_full_with_state(
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
-                    }
+                    batch.n_tokens = 0;
+
+                    const int n_past = prompt.size() + i;
+
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
 
-                    decoder.tokens_tmp.resize(1);
-                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
+
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    // TODO: avoid memory allocations, optimize, avoid threads?
                     {
-                        const int64_t t_start_sample_us = ggml_time_us();
+                        std::atomic<int> j_cur(0);
+
+                        auto process = [&]() {
+                            while (true) {
+                                const int j = j_cur.fetch_add(1);
+
+                                if (j >= n_decoders_cur) {
+                                    break;
+                                }
 
-                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                                auto & decoder = state->decoders[j];
 
-                        ++decoder.kv_self.n;
+                                if (decoder.failed || decoder.completed) {
+                                    continue;
+                                }
 
-                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                                whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+                            }
+                        };
+
+                        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                        if (n_threads == 1) {
+                            process();
+                        } else {
+                            std::vector<std::thread> threads(n_threads - 1);
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t] = std::thread(process);
+                            }
+
+                            process();
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t].join();
+                            }
+                        }
                     }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
 
@@ -5759,11 +5933,13 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_batchd_us += states[i]->t_batchd_us;
         ctx->state->t_prompt_us += states[i]->t_prompt_us;
 
         ctx->state->n_sample += states[i]->n_sample;
         ctx->state->n_encode += states[i]->n_encode;
         ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_batchd += states[i]->n_batchd;
         ctx->state->n_prompt += states[i]->n_prompt;
 
         whisper_free_state(states[i]);
index 50f84a822d083e8ac6965b66093f2b5af1146538..84540989d2ed339fb1ef4e834fa222af44e49e7b 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -78,7 +78,9 @@ extern "C" {
     struct whisper_state;
     struct whisper_full_params;
 
-    typedef int whisper_token;
+    typedef int32_t whisper_pos;
+    typedef int32_t whisper_token;
+    typedef int32_t whisper_seq_id;
 
     struct whisper_context_params {
         bool  use_gpu;