]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add full CUDA and Metal offloading (#1472)
authorGeorgi Gerganov <redacted>
Sun, 12 Nov 2023 13:31:08 +0000 (15:31 +0200)
committerGitHub <redacted>
Sun, 12 Nov 2023 13:31:08 +0000 (15:31 +0200)
* whisper : migrate to ggml-backend

* whisper : fix logit reading

* whisper : fix tensor allocation during load

* whisper : fix beam-search with CUDA

* whisper : free backends + fix compile warning

* whisper : print when CUDA is enabled

* whisper : fix CoreML

* make : clean-up

* talk : fix compile warning

* whisper : support ggml_conv with CUDA and Metal (#1473)

* ggml : add CUDA support for ggml_conv

* whisper : remove ggml_repeat for conv bias + single backend

* cuda : fix im2col kernel

* metal : add im2col support + mul mat-vec f16 x f16

* bench-all : add q4 models

* whisper : clean-up

* quantize-all : fix

* ggml : im2col opts

* whisper : avoid whisper_model_data wrapper

* whisper : add note that ggml_mul_mat_pad does not work with CUDA

* whisper : factor out graph compute in common function

* whisper : fixes

* whisper : fix UB with measure buffers

* whisper : try to fix the parallel whisper_state functionality (#1479)

* whisper : try to fix the parallel whisper_state functionality

* whisper : fix multi-state Metal

* whisper : free backend instances in whisper_state

14 files changed:
.gitignore
Makefile
examples/common.h
examples/talk/gpt-2.cpp
extra/bench-all.sh
extra/quantize-all.sh
ggml-cuda.cu
ggml-metal.h
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
whisper.cpp
whisper.h

index d5c4b0caf484b2f188db02142d53d8b5fe534aca..9ff35d00b6ac5378eed4bb55aee5d000cc74231a 100644 (file)
@@ -8,6 +8,7 @@
 .DS_Store
 
 build/
+build-coreml/
 build-em/
 build-debug/
 build-release/
index d134b768bc6a659ed842dde4a6d6d333972a4bf5..20b02c3a8364a0ea83b983e64dd337c631d962cb 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
        $(CC)  $(CFLAGS)   -c $< -o $@
 
-WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
+WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
 
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
@@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
 WHISPER_OBJ += ggml-metal.o
 endif
 
-libwhisper.a: ggml.o $(WHISPER_OBJ)
-       $(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
+libwhisper.a: $(WHISPER_OBJ)
+       $(AR) rcs libwhisper.a $(WHISPER_OBJ)
 
-libwhisper.so: ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
+libwhisper.so: $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
 
 clean:
        rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
@@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON     = examples/common.cpp examples/common-ggml.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp
 
-main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
        ./main -h
 
-bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
 
-quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
-       $(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
+quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
+       $(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
 
-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
 
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
 
-lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
+lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
 
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-       $(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+       $(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
 
 #
 # Audio samples
index 9a94bab7a7099c569f6f559ca1e6a15547ce4f3e..54f0b00d0ef41fd4e65c969663536c85bcb7680c 100644 (file)
@@ -181,7 +181,7 @@ private:
     // It is assumed that PCM data is normalized to a range from -1 to 1
     bool write_audio(const float * data, size_t length) {
         for (size_t i = 0; i < length; ++i) {
-            const auto intSample = static_cast<const int16_t>(data[i] * 32767);
+            const int16_t intSample = data[i] * 32767;
             file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
             dataSize += sizeof(int16_t);
         }
index a2319db6be6a10a6edac2346a2b8897b57c0a8d7..8f9a3e93b76f6fd889870be9875c37806d18c79e 100644 (file)
@@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             return false;
         }
 
-        std::string word;
+        char word[129];
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             fin.read((char *) &len, sizeof(len));
-
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            word[len] = '\0';
+            fin.read((char *) word, len);
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
index 8fd18b7d16a35064a386149f1379ea3fdf4c6c59..db042673d698e3872a4467edfede25f41e9711d6 100755 (executable)
@@ -18,11 +18,11 @@ else
 fi
 
 models=(                                               \
-      "tiny"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
-      "base"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
-     "small"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
-    "medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
-     "large"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
+      "tiny"   "tiny-q4_0"   "tiny-q4_1"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
+      "base"   "base-q4_0"   "base-q4_1"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
+     "small"  "small-q4_0"  "small-q4_1"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
+    "medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
+     "large"  "large-q4_0"  "large-q4_1"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
 )
 
 if [ "$encoder_only" -eq 0 ]; then
@@ -83,6 +83,10 @@ for model in "${models[@]}"; do
         config="$config COREML"
     fi
 
+    if [[ $system_info == *"CUDA = 1"* ]]; then
+        config="$config CUDA"
+    fi
+
     if [[ $system_info == *"METAL = 1"* ]]; then
         config="$config METAL"
     fi
index bfef21eddada0d1ab939b9548139ab4f2ce01689..767462b81074f86a81627af5fefa5efef0c07d53 100755 (executable)
@@ -15,33 +15,13 @@ declare -a filedex
 cd `dirname $0`
 cd ../
 
-# Let's loop across all the objects in the 'models' dir:
-for i in ./models/*; do
-    # Check to see if it's a file or directory
-    if [ -d "$i" ]; then
-        # It's a directory! We should make sure it's not empty first:
-        if [ "$(ls -A $i)" ]; then
-            # Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
-            for f in "$i"/*.bin; do
-                # [Neuron Activation]
-                newfile=`echo "${f##*/}" | cut -d _ -f 1`;
-                if [ "$newfile" != "q5" ]; then
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
-                    filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
-                fi
-            done
-        fi
-    else
-        # It's a file! Let's make sure it's the right type:
-        if [ "${i##*.}" == "bin" ]; then
-            # And we probably want to skip the testing files
-            if [ "${i:9:8}" != "for-test" ]; then
-                # [Neuron Activation]
-                ./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
-                ./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
-                filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
-            fi
+for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
+    m="models/$i"
+    if [ -f "$m" ]; then
+        if [ "${m##*.}" == "bin" ]; then
+            ./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
+            ./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
+            filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
         fi
     fi
 done
index f4a6795594df54e451ed88fb2f9caf41f5473541..34c45f3882bcf9d01ddcffcebfed85605f82a0d0 100644 (file)
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static  __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+       const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N,  int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N,  KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_f16_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }
index 096b844e32c6fef91e18033c4638acecc2f322c8..be2731f8ba476728c50b4971a67a74a9437b981d 100644 (file)
@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
index 3bee83970b4c3d0e0d9bff3077e9af4d118c8101..6293908ca6fe5ce6446f22ed87e42fe02013f020 100644 (file)
@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -473,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -1139,6 +1149,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                         } break;
@@ -1146,13 +1157,18 @@ void ggml_metal_graph_compute(
                                         {
                                             nth0 = 32;
                                             nth1 = 1;
-                                            if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                                nrows = ne11;
+                                            if (src1t == GGML_TYPE_F32) {
+                                                if (ne11 * ne12 < 4) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                                    nrows = ne11;
+                                                } else {
+                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                    nrows = 4;
+                                                }
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1464,6 +1480,58 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_IM2COL:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                            GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                            const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                            const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                            const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                            const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                            const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                            const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                            const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                            const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                            const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                            const int32_t IH = is_2D ? src1->ne[1] : 1;
+                            const int32_t IW =         src1->ne[0];
+
+                            const int32_t KH = is_2D ? src0->ne[1] : 1;
+                            const int32_t KW =         src0->ne[0];
+
+                            const int32_t OH = is_2D ? dst->ne[2] : 1;
+                            const int32_t OW =         dst->ne[1];
+
+                            const int32_t CHW = IC * KH * KW;
+
+                            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
+                            [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
+                            [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
+                            [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
+                            [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
+                            [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
+                            [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
+                            [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
+                            [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
+                            [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
index 7c35f23a7612fd75362457b3fc9d137cd37e0bfa..5d1357cd72d4592782802a60222e81d2cacb8d8f 100644 (file)
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device       half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
diff --git a/ggml.c b/ggml.c
index d1b7e94ddc2ebbe1480fd5b7bacf6cab1f91b4f2..584ee4680378d9ddb3722d61e618d9837a07a88d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D",
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
-    "CONV_2D",
-    "CONV_2D_STAGE_0",
-    "CONV_2D_STAGE_1",
+    "IM2COL",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d(x)",
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
-    "conv_2d(x)",
-    "conv_2d_stage_0(x)",
-    "conv_2d_stage_1(x)",
+    "im2col(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_GET_ROWS_BACK          ] = true;
         p[GGML_OP_DIAG_MASK_INF          ] = true;
         p[GGML_OP_DIAG_MASK_ZERO         ] = true;
-        p[GGML_OP_CONV_1D                ] = true;
-        p[GGML_OP_CONV_1D_STAGE_0        ] = true;
-        p[GGML_OP_CONV_1D_STAGE_1        ] = true;
         p[GGML_OP_CONV_TRANSPOSE_1D      ] = true;
-        p[GGML_OP_CONV_2D                ] = true;
-        p[GGML_OP_CONV_2D_STAGE_0        ] = true;
-        p[GGML_OP_CONV_2D_STAGE_1        ] = true;
         p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
         p[GGML_OP_FLASH_ATTN_BACK        ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
@@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-// im2col: [N, IC, IL] => [N, OL, IC*K]
-// a: [OC,IC, K]
-// b: [N, IC, IL]
-// result: [N, OL, IC*K]
-static struct ggml_tensor * ggml_conv_1d_stage_0(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                   s0,
-    int                   p0,
-    int                   d0) {
-    GGML_ASSERT(a->ne[1] == b->ne[1]);
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
-
-    const int64_t ne[4] = {
-        a->ne[1] * a->ne[0],
-        OL,
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
-
-    int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = GGML_OP_CONV_1D_STAGE_0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d_stage_1
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// a: [OC, IC, K]
-// b: [N, OL, IC * K]
-// result: [N, OC, OL]
-static struct ggml_tensor * ggml_conv_1d_stage_1(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        a->ne[2],
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_1D_STAGE_1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d
-
 GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5220,43 +5128,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int                   s0,
         int                   p0,
         int                   d0) {
-    struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
-    result = ggml_conv_1d_stage_1(ctx, a, result);
-    return result;
-}
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
 
-// GGML_API struct ggml_tensor * ggml_conv_1d(
-//         struct ggml_context * ctx,
-//         struct ggml_tensor  * a,
-//         struct ggml_tensor  * b,
-//         int                   s0,
-//         int                   p0,
-//         int                   d0) {
-//     GGML_ASSERT(ggml_is_matrix(b));
-//     GGML_ASSERT(a->ne[1] == b->ne[1]);
-//     bool is_node = false;
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
 
-//     if (a->grad || b->grad) {
-//         GGML_ASSERT(false); // TODO: implement backward
-//         is_node = true;
-//     }
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
 
-//     const int64_t ne[4] = {
-//         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
-//         a->ne[2], 1, 1,
-//     };
-//     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
-
-//     int32_t params[] = { s0, p0, d0 };
-//     ggml_set_op_params(result, params, sizeof(params));
-
-//     result->op = GGML_OP_CONV_1D;
-//     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-//     result->src[0] = a;
-//     result->src[1] = b;
-
-//     return result;
-// }
+    return result;
+}
 
 // ggml_conv_1d_ph
 
@@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
-static struct ggml_tensor * ggml_conv_2d_stage_0(
+struct ggml_tensor * ggml_im2col(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
     struct ggml_tensor  * b,
@@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
     int                  p0,
     int                  p1,
     int                  d0,
-    int                  d1) {
+    int                  d1,
+    bool                 is_2D) {
 
-    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    if(is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        GGML_ASSERT(a->ne[1] == b->ne[1]);
+    }
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
         is_node = true;
     }
 
-    const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
-    const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
 
     const int64_t ne[4] = {
-        a->ne[2] * a->ne[1] * a->ne[0],
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
-        OH,
-        b->ne[3],
+        is_2D ? OH : b->ne[2],
+        is_2D ?      b->ne[3] : 1,
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
 
-    int32_t params[] = { s0, s1, p0, p1, d0, d1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_CONV_2D_STAGE_0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// a: [OC, IC, KH, KW]
-// b: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static struct ggml_tensor * ggml_conv_2d_stage_1(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        b->ne[2],
-        a->ne[3],
-        b->ne[3],
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_2D_STAGE_1;
+    result->op = GGML_OP_IM2COL;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
 
     return result;
-
 }
 
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
 struct ggml_tensor * ggml_conv_2d(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                  s0,
-    int                  s1,
-    int                  p0,
-    int                  p1,
-    int                  d0,
-    int                  d1) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                  s0,
+        int                  s1,
+        int                  p0,
+        int                  p1,
+        int                  d0,
+        int                  d1) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
 
-    struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
-    result = ggml_conv_2d_stage_1(ctx, a, result);
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
 
-    return result;
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
 
+    return result;
 }
 
 // ggml_conv_2d_sk_p0
@@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
+        src0->type == GGML_TYPE_F32 &&
+        src1->type == GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat(
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -11637,9 +11497,9 @@ static void ggml_compute_forward_rope_back(
     }
 }
 
-// ggml_compute_forward_conv_1d
+// ggml_compute_forward_conv_transpose_1d
 
-static void ggml_compute_forward_conv_1d_f16_f32(
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11656,14 +11516,7 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-
-    // size of the convolution row - the kernel size unrolled across all input channels
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -11671,23 +11524,37 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            ggml_fp16_t * dst_data = wdata;
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
 
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
-                    }
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
@@ -11695,8 +11562,10 @@ static void ggml_compute_forward_conv_1d_f16_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr = ne2;
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11705,22 +11574,26 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f16(ew0, dst_data + i0,
-                        (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
-                        (ggml_fp16_t *)                wdata + i2*nb2 + i0*ew0);
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v,
+                        (ggml_fp16_t *)    wdata_src + i1n,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-static void ggml_compute_forward_conv_1d_f32(
+static void ggml_compute_forward_conv_transpose_1d_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11737,13 +11610,7 @@ static void ggml_compute_forward_conv_1d_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nk = ne00;
-
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -11751,23 +11618,37 @@ static void ggml_compute_forward_conv_1d_f32(
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
-        float * const wdata = (float *) params->wdata + 0;
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
 
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            float * dst_data = wdata;
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
 
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
-                    }
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
@@ -11775,8 +11656,10 @@ static void ggml_compute_forward_conv_1d_f32(
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
-    const int nr = ne02;
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11785,94 +11668,50 @@ static void ggml_compute_forward_conv_1d_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * const wdata = (float *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f32(ew0, dst_data + i0,
-                        (float *) ((char *) src0->data + i1*nb02),
-                        (float *)                wdata + i2*nb2 + i0*ew0);
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v,
+                        wdata_src + i1n,
+                        wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
-// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
-static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
-                             ggml_fp16_t * A,
-                             ggml_fp16_t * B,
-                             float * C,
-                             const int ith, const int nth) {
-    // does not seem to make a difference
-    int64_t m0, m1, n0, n1;
-    // patches per thread
-    if (m > n) {
-        n0 = 0;
-        n1 = n;
-
-        // total patches in dst
-        const int np = m;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        m0 = dp*ith;
-        m1 = MIN(m0 + dp, np);
-    } else {
-        m0 = 0;
-        m1 = m;
-
-        // total patches in dst
-        const int np = n;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        n0 = dp*ith;
-        n1 = MIN(n0 + dp, np);
-    }
-
-    // block-tiling attempt
-    int64_t blck_n = 16;
-    int64_t blck_m = 16;
-
-    // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
-    // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
-    // if (blck_size > 0) {
-    //     blck_0 = 4;
-    //     blck_1 = blck_size / blck_0;
-    //     if (blck_1 < 0) {
-    //         blck_1 = 1;
-    //     }
-    //     // blck_0 = (int64_t)sqrt(blck_size);
-    //     // blck_1 = blck_0;
-    // }
-    // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
-
-    for (int j = n0; j < n1; j+=blck_n) {
-        for (int i = m0; i < m1; i+=blck_m) {
-            // printf("i j k => %d %d %d\n", i, j, K);
-            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
-                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
-                    ggml_vec_dot_f16(k,
-                                    C + ii*n + jj,
-                                    A + ii * k,
-                                    B + jj * k);
-                }
-            }
-        }
+static void ggml_compute_forward_conv_transpose_1d(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
     }
 }
 
-// src0: kernel [OC, IC, K]
-// src1: signal [N, IC, IL]
-// dst:  result [N, OL, IC*K]
-static void ggml_compute_forward_conv_1d_stage_0_f32(
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11886,26 +11725,35 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const int64_t N  = ne12;
-    const int64_t IC = ne11;
-    const int64_t IL = ne10;
-
-    const int64_t K = ne00;
-
-    const int64_t OL = ne1;
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
         return;
     }
 
@@ -11913,23 +11761,30 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
         return;
     }
 
-    // im2col: [N, IC, IL] => [N, OL, IC*K]
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 
         for (int64_t in = 0; in < N; in++) {
-            for (int64_t iol = 0; iol < OL; iol++) {
-                for (int64_t iic = ith; iic < IC; iic+=nth) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
 
-                    // micro kernel
-                    ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
-                    const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
 
-                    for (int64_t ik = 0; ik < K; ik++) {
-                        const int64_t iil = iol*s0 + ik*d0 - p0;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
 
-                        if (!(iil < 0 || iil >= IL)) {
-                            dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
                         }
                     }
                 }
@@ -11938,627 +11793,7 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
     }
 }
 
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// src0: [OC, IC, K]
-// src1: [N, OL, IC * K]
-// result: [N, OC, OL]
-static void ggml_compute_forward_conv_1d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne12;
-    const int OL = ne11;
-
-    const int OC = ne02;
-    const int IC = ne01;
-    const int K  = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OL;
-    int64_t k = IC * K;
-
-    // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_1(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (L x Cin) to (Cin x L)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f16(ne02, &v,
-                        (ggml_fp16_t *)    wdata_src + i1n,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            float * const wdata = (float *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // prepare source data (src1)
-        {
-            float * const wdata = (float *) params->wdata + nk;
-            float * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = src[i10];
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata     = (float *) params->wdata + 0;
-    float * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        float * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f32(ne02, &v,
-                        wdata_src + i1n,
-                        wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_2d
-
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_conv_2d_stage_0_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int64_t N = ne13;
-    const int64_t IC = ne12;
-    const int64_t IH = ne11;
-    const int64_t IW = ne10;
-
-    // const int64_t OC = ne03;
-    // const int64_t IC = ne02;
-    const int64_t KH = ne01;
-    const int64_t KW = ne00;
-
-    const int64_t OH = ne2;
-    const int64_t OW = ne1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) {
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic+=nth) {
-
-                        // micro kernel
-                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// src0: [OC, IC, KH, KW]
-// src1: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static void ggml_compute_forward_conv_2d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne13;
-    const int OH = ne12;
-    const int OW = ne11;
-
-    const int OC = ne03;
-    const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // src1: image [N, IC, IH, IW]
-    // src0: kernel [OC, IC, KH, KW]
-    // dst:  result [N, OC, OH, OW]
-    // ne12: IC
-    // ne0: OW
-    // ne1: OH
-    // nk0: KW
-    // nk1: KH
-    // ne13: N
-
-    const int N = ne13;
-    const int IC = ne12;
-    const int IH = ne11;
-    const int IW = ne10;
-
-    const int OC = ne03;
-    // const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int OH = ne1;
-    const int OW = ne0;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // const int nk0 = ne00;
-    // const int nk1 = ne01;
-
-    // size of the convolution row - the kernel size unrolled across all channels
-    // const int ew0 = nk0*nk1*ne02;
-    // ew0: IC*KH*KW
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare source data (src1)
-        // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
-
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int in = 0; in < N; in++) {
-                for (int iic = 0; iic < IC; iic++) {
-                    for (int ioh = 0; ioh < OH; ioh++) {
-                        for (int iow = 0; iow < OW; iow++) {
-
-                            // micro kernel
-                            ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                            for (int ikh = 0; ikh < KH; ikh++) {
-                                for (int ikw = 0; ikw < KW; ikw++) {
-                                    const int iiw = iow*s0 + ikw*d0 - p0;
-                                    const int iih = ioh*s1 + ikh*d1 - p1;
-
-                                    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    // wdata: [N*OH*OW, IC*KH*KW]
-    // dst: result [N, OC, OH, OW]
-    // src0: kernel [OC, IC, KH, KW]
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m * k]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_1(
+static void ggml_compute_forward_im2col(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12566,7 +11801,7 @@ static void ggml_compute_forward_conv_2d_stage_1(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
+                ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
@@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
-                ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
                 n_tasks = n_threads;
             } break;
@@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
+                printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
                 GGML_ASSERT(false);
             } break;
     }
@@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
-            case GGML_OP_CONV_1D:
-                {
-                    GGML_ASSERT(node->src[0]->ne[3] == 1);
-                    GGML_ASSERT(node->src[1]->ne[2] == 1);
-                    GGML_ASSERT(node->src[1]->ne[3] == 1);
-
-                    const int64_t ne00 = node->src[0]->ne[0];
-                    const int64_t ne01 = node->src[0]->ne[1];
-                    const int64_t ne02 = node->src[0]->ne[2];
-
-                    const int64_t ne10 = node->src[1]->ne[0];
-                    const int64_t ne11 = node->src[1]->ne[1];
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t nk  = ne00;
-                    const int64_t ew0 = nk * ne01;
-
-                    UNUSED(ne02);
-                    UNUSED(ne10);
-                    UNUSED(ne11);
-
-                    if (node->src[0]->type == GGML_TYPE_F16 &&
-                        node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
-                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                               node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(float)*(ne0*ne1*ew0);
-                    } else {
-                        GGML_ASSERT(false);
-                    }
-                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_CONV_2D:
+            case GGML_OP_IM2COL:
                 {
-                    const int64_t ne00 = node->src[0]->ne[0]; // W
-                    const int64_t ne01 = node->src[0]->ne[1]; // H
-                    const int64_t ne02 = node->src[0]->ne[2]; // C
-                    const int64_t ne03 = node->src[0]->ne[3]; // N
-
-                    const int64_t ne10 = node->src[1]->ne[0]; // W
-                    const int64_t ne11 = node->src[1]->ne[1]; // H
-                    const int64_t ne12 = node->src[1]->ne[2]; // C
-
-                    const int64_t ne0 = node->ne[0];
-                    const int64_t ne1 = node->ne[1];
-                    const int64_t ne2 = node->ne[2];
-                    const int64_t ne3 = node->ne[3];
-                    const int64_t nk = ne00*ne01;
-                    const int64_t ew0 = nk * ne02;
-
-                    UNUSED(ne03);
-                    UNUSED(ne2);
-
-                    if (node->src[0]->type == GGML_TYPE_F16 &&
-                        node->src[1]->type == GGML_TYPE_F32) {
-                        // im2col: [N*OH*OW, IC*KH*KW]
-                        cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
-                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                               node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(float)*      (ne10*ne11*ne12);
-                    } else {
-                        GGML_ASSERT(false);
-                    }
+                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
diff --git a/ggml.h b/ggml.h
index e56a8337a507747eae914cd355c08d791c68ba4c..52ae6755a89ad2e8746642233fea80a1d149937a 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0, // internal
-        GGML_OP_CONV_2D_STAGE_1, // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -1398,6 +1393,18 @@ extern "C" {
             float                 min,
             float                 max);
 
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 681727f5cd16b85f7c3f737cd6c4034052e92967..244cfeb102c57756e186d1ee12b6815149832f30 100644 (file)
@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef GGML_USE_METAL
-#  include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
@@ -13,6 +18,7 @@
 
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #include <algorithm>
 #include <cassert>
@@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (ggml_log_level level, const char* format, ...);
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
@@ -127,8 +155,8 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //
 
 static void ggml_graph_compute_helper(
+          struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-                 ggml_cgraph * graph,
                          int   n_threads,
       whisper_abort_callback   abort_callback,
                         void * abort_callback_data) {
@@ -145,6 +173,21 @@ static void ggml_graph_compute_helper(
     ggml_graph_compute(graph, &plan);
 }
 
+static void ggml_graph_compute_helper(
+       struct ggml_backend * backend,
+        struct ggml_cgraph * graph,
+                       int   n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
@@ -179,6 +222,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
 #if defined(GGML_USE_METAL)
 #define ggml_mul_mat ggml_mul_mat_pad
 #endif
@@ -305,75 +349,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99,  "cantonese",      } },
 };
 
-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -554,8 +529,7 @@ struct whisper_kv_cache {
 
     struct ggml_context * ctx;
 
-    // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
-    std::vector<uint8_t> buf;
+    ggml_backend_buffer_t buffer;
 
     int n; // number of tokens currently in the cache
 };
@@ -594,11 +568,11 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // context
+    // ggml context that contains all the meta information about the model tensors
     struct ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
@@ -663,37 +637,47 @@ struct whisper_allocr {
     ggml_allocr * alloc = nullptr;
 
     std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+
+    ggml_backend_buffer_t buffer;
 };
 
 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc  = allocr.alloc;
+    auto & meta   = allocr.meta;
 
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
+    alloc = ggml_allocr_new_measure_from_backend(backend);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
 
-    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
 
-    ggml_allocr_free(alloc);
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
 
-    data.resize(alloc_size);
+    size_t size = ggml_allocr_max_size(alloc);
 
-    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc = ggml_allocr_new_from_buffer(buffer);
 }
 
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
         allocr.alloc = nullptr;
     }
 }
@@ -722,8 +706,7 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // reusable buffer for `struct ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -737,6 +720,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
+    // helper for GPU offloading
+    std::vector<float> inp_mel;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
@@ -751,22 +737,21 @@ struct whisper_state {
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef GGML_USE_METAL
-    ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +765,25 @@ struct whisper_context {
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
+    ggml_backend_t backend = nullptr;
+
     std::string path_model; // populated by whisper_init_from_file_with_params()
-    whisper_context_params params;
 };
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
+struct whisper_global {
+    // We save the log callback globally
+    ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-static whisper_log_callback whisper_log = whisper_default_log;
-
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +794,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
+                      ggml_backend_t   backend,
                            ggml_type   wtype,
                                  int   n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -827,30 +803,41 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+// TODO: remove after batched decoding
+static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
 
     const int n_elements = ggml_nelements(cache.k);
@@ -859,34 +846,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     const ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
         cache.ctx = nullptr;
     }
 }
 
+static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef GGML_USE_CUBLAS
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = ggml_backend_cuda_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return ggml_backend_cpu_init();
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -899,7 +930,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    log("%s: loading model\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
 
     const int64_t t_start_us = ggml_time_us();
 
@@ -913,7 +944,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != GGML_FILE_MAGIC) {
-            log("%s: invalid model data (bad magic)\n", __func__);
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -970,41 +1001,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         // in order to save memory and also to speed up the computation
         wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == GGML_TYPE_COUNT) {
-            log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
             return false;
         }
 
-        const size_t scale = model.hparams.ftype ? 1 : 2;
-
-        log("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        log("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        log("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        log("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        log("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        log("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        log("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        log("%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        log("%s: ftype         = %d\n", __func__, model.hparams.ftype);
-        log("%s: qntvr         = %d\n", __func__, qntvr);
-        log("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
-
-        // print memory requirements
-        {
-            // TODO
-            //log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-            //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-        }
-
-        // initialize all memory buffers
-        // always have at least one decoder
-
-        wctx.model.buf = new std::vector<uint8_t>();
-        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
-
-        // we skip initialization of the state until it is needed
-        // because it might be that state will always be provided externally.
+        WHISPER_LOG_INFO("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        WHISPER_LOG_INFO("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        WHISPER_LOG_INFO("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        WHISPER_LOG_INFO("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        WHISPER_LOG_INFO("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        WHISPER_LOG_INFO("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        WHISPER_LOG_INFO("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        WHISPER_LOG_INFO("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        WHISPER_LOG_INFO("%s: qntvr         = %d\n", __func__, qntvr);
+        WHISPER_LOG_INFO("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
     }
 
     // load mel filters
@@ -1025,7 +1038,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
-        //    log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}
@@ -1045,7 +1058,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                 word = "";
             }
 
@@ -1073,7 +1086,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1099,140 +1112,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
         }
 
-        log("%s: n_langs       = %d\n", __func__, vocab.num_languages());
+        WHISPER_LOG_INFO("%s: n_langs       = %d\n", __func__, vocab.num_languages());
     }
 
-    size_t ctx_size = 0;
-
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
+    // create the ggml context
     {
         const auto & hparams = model.hparams;
 
-        const int n_vocab = hparams.n_vocab;
-
-        const int n_audio_ctx   = hparams.n_audio_ctx;
-        const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
+        const int n_text_layer  = hparams.n_text_layer;
 
-        const int n_text_ctx   = hparams.n_text_ctx;
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-
-        const int n_mels = hparams.n_mels;
-
-        // encoder
-        {
-            ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
-
-            ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype);         // e_conv_1_w
-            ctx_size +=          n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
-
-            ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype);         // e_conv_2_w
-            ctx_size +=                 n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
-
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
-        }
-
-        // decoder
-        {
-            ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
-
-            ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
-
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
-        }
-
-        // encoder layers
-        {
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // mlp_0_w
-            ctx_size += n_audio_layer*(              4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // mlp_1_w
-            ctx_size += n_audio_layer*(                n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_q_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_v_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_ln_1_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-        }
-
-        // decoder layers
-        {
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype));         // mlp_0_w
-            ctx_size += n_text_layer*(             4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype));         // mlp_1_w
-            ctx_size += n_text_layer*(               n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_q_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_v_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_ln_1_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-                                                                                                //
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_q_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_v_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_ln_1_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
-        }
-
-        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
-
-        log("%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-    }
+        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-    // create the ggml context
-    {
         struct ggml_init_params params = {
-            /*.mem_size   =*/ wctx.model.buf->size(),
-            /*.mem_buffer =*/ wctx.model.buf->data(),
-            /*.no_alloc   =*/ false,
+            /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
         };
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
-            log("%s: ggml_init() failed\n", __func__);
+            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
             return false;
         }
     }
 
-    // prepare memory for the weights
+    // prepare tensors for the weights
     {
         auto & ctx = model.ctx;
 
@@ -1255,16 +1163,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         // encoder
         {
-            model.e_pe       = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+            model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
-            model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype,         3, n_mels, n_audio_state);
-            model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_1_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,     n_audio_state);
+            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
 
-            model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
-            model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_2_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
+            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    n_audio_ctx,   n_audio_state);
 
-            model.e_ln_w     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1428,12 +1336,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
+        }
+
+        model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+    }
+
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1460,50 +1393,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
-            if (ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                return false;
-            }
 
-            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                return false;
-            }
+            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
 
-            const size_t bpe = ggml_type_size(ggml_type(ttype));
+            if (!is_conv_bias) {
+                if (ggml_nelements(tensor) != nelements) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                            __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                    return false;
+                }
 
-            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                return false;
+                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                            __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                    return false;
+                }
+
+                const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                            __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                    return false;
+                }
             }
 
-            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
-            BYTESWAP_TENSOR(tensor);
+            ggml_backend_t backend = wctx.backend;
+
+            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
+
+            if ((ggml_backend_is_cpu(backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(backend)
+#endif
+                ) && !is_conv_bias) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                // we repeat the 2 bias tensors along dim 0:
+                // [1, 512] -> [3000, 512] (conv1.bias)
+                // [1, 512] -> [1500, 512] (conv2.bias)
+                if (is_conv_bias) {
+                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
+
+                    float * data_f32 = (float *) read_buf.data();
+                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
+                        const int64_t yy = tensor->ne[1] - y - 1;
+                        const float val = data_f32[yy];
+
+                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
+                            data_f32[yy*tensor->ne[0] + x] = val;
+                        }
+                    }
+                } else {
+                    loader->read(loader->context, read_buf.data(), read_buf.size());
+                }
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
-            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
         } else if (model.n_loaded != (int) model.tensors.size()) {
-            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
 
+    ggml_allocr_free(alloc);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1559,10 +1534,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     if (!ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
-        float * dst = (float *) mel->data;
+        wstate.inp_mel.resize(ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, ggml_nbytes(mel));
 
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i0 = std::min(mel_offset,           mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1570,6 +1547,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
     }
 
     struct ggml_tensor * cur = nullptr;
@@ -1578,24 +1557,27 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         // convolution + gelu
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_1_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_1_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_1_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
 
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_2_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_2_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_2_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
         }
 
+        ggml_set_name(cur, "embd_conv");
         wstate.embd_conv = cur;
     } else {
 #ifdef WHISPER_USE_COREML
@@ -1615,6 +1597,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         }
 #endif
 
+        ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
     }
 
@@ -1648,15 +1631,22 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_conv, cur);
+    //}
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
+        const float val = 1.0f/sqrtf(float(n_state)/n_head);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
-
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
@@ -1675,7 +1665,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
 
     // ===================================================================
@@ -1897,13 +1886,20 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_enc, cur);
+    //}
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, Kscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+        const float val = pow(float(n_state) / n_head, -0.25);
+        ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
     }
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1974,7 +1970,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
 
@@ -1988,16 +1984,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
@@ -2010,20 +1997,9 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;
 
@@ -2070,7 +2046,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
     }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -2078,7 +2054,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     if (!ggml_allocr_is_measure(alloc)) {
         for (int i = 0; i < N; ++i) {
-            ((int32_t *) position->data)[i] = n_past + i;
+            const int32_t val = n_past + i;
+            ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
 
@@ -2086,7 +2063,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
+        const float val = pow(float(n_state)/n_head, -0.25);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
     // token encoding + position encoding
@@ -2410,25 +2388,18 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
     //logits_out.resize(n_tokens*n_vocab);
     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
+    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
-    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
 
     if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -2794,7 +2765,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
                 --j;
             }
             if (!found) {
-                log("unknown token\n");
+                WHISPER_LOG_ERROR("unknown token\n");
                 ++i;
             }
         }
@@ -2857,45 +2828,48 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     fill_sin_cos_table();
+
     whisper_state * state = new whisper_state;
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
-        log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+    state->backend = whisper_backend_init(ctx->params);
+
+    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
-        log("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
-        log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
-        log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
-    log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
-        log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+        WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
         delete state;
         return nullptr;
 #endif
     } else {
-        log("%s: Core ML model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
     }
 #endif
 
@@ -2912,37 +2886,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        whisper_allocr_graph_init(state->alloc_conv,
+        whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
 
-        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        whisper_allocr_graph_init(state->alloc_encode,
+        whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
 
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
     }
 
     // cross allocator
     {
-        whisper_allocr_graph_init(state->alloc_cross,
+        whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
 
-        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
     }
 
     // decoder allocator
     {
-        whisper_allocr_graph_init(state->alloc_decode,
+        whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -2953,69 +2927,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
                 });
 
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }
 
-#ifdef GGML_USE_METAL
-    if (ctx->params.use_gpu) {
-        state->ctx_metal = ggml_metal_init(1);
-        if (!state->ctx_metal) {
-            log("%s: ggml_metal_init() failed\n", __func__);
-            delete state;
-            return nullptr;
-        }
-    }
-
-    if (state->ctx_metal) {
-        log("%s: Metal context initialized\n", __func__);
-
-        // this allocates all Metal resources and memory buffers
-
-        void * data_ptr  = NULL;
-        size_t data_size = 0;
-
-        // TODO: add mmap support
-        //if (params.use_mmap) {
-        //    data_ptr  = ctx->model.mapping->addr;
-        //    data_size = ctx->model.mapping->size;
-        //} else {
-        //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-        //    data_size = ggml_get_mem_size  (ctx->model.ctx);
-        //}
-
-        data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-        data_size = ggml_get_mem_size  (ctx->model.ctx);
-
-        const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
-
-        log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
-
-#define WHISPER_METAL_CHECK_BUF(result)              \
-        if (!(result)) {                                 \
-            log("%s: failed to add metal buffer\n", __func__); \
-            delete state;                                \
-            return nullptr;                              \
-        }
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross",  state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-
-    }
-#endif
+    whisper_allocr_graph_realloc(state->alloc_conv,   ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
     state->rng = std::mt19937(0);
 
@@ -3036,7 +2954,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }
 
@@ -3056,15 +2974,15 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }
 
-    log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
-        log("%s: OpenVINO model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
     }
 
     return 0;
@@ -3079,11 +2997,11 @@ struct whisper_context_params whisper_context_default_params() {
 }
 
 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
-    log("%s: loading model from '%s'\n", __func__, path_model);
+    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        log("%s: failed to open '%s'\n", __func__, path_model);
+        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }
 
@@ -3125,7 +3043,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
 
-    log("%s: loading model from buffer\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
 
     whisper_model_loader loader = {};
 
@@ -3161,7 +3079,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        log("%s: failed to load model\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
@@ -3256,13 +3174,6 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
-#ifdef GGML_USE_METAL
-        if (state->ctx_metal) {
-            ggml_metal_free(state->ctx_metal);
-            state->ctx_metal = nullptr;
-        }
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
         if (state->ctx_openvino != nullptr) {
             whisper_openvino_free(state->ctx_openvino);
@@ -3271,9 +3182,11 @@ void whisper_free_state(struct whisper_state * state)
 #endif
 
         whisper_allocr_free(state->alloc_conv);
-        whisper_allocr_free(state->alloc_decode);
-        whisper_allocr_free(state->alloc_cross);
         whisper_allocr_free(state->alloc_encode);
+        whisper_allocr_free(state->alloc_cross);
+        whisper_allocr_free(state->alloc_decode);
+
+        ggml_backend_free(state->backend);
 
         delete state;
     }
@@ -3284,12 +3197,15 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.buf) {
-            delete ctx->model.buf;
+
+        if (ctx->model.buffer) {
+            ggml_backend_buffer_free(ctx->model.buffer);
         }
 
         whisper_free_state(ctx->state);
 
+        ggml_backend_free(ctx->backend);
+
         delete ctx;
     }
 }
@@ -3308,7 +3224,7 @@ void whisper_free_params(struct whisper_full_params * params) {
 
 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
@@ -3322,7 +3238,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
@@ -3350,7 +3266,7 @@ int whisper_set_mel_with_state(
                            int   n_len,
                            int   n_mel) {
     if (n_mel != ctx->model.filters.n_mel) {
-        log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
+        WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
         return -1;
     }
 
@@ -3374,7 +3290,7 @@ int whisper_set_mel(
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3383,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3394,7 +3310,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
     const int selected_decoder_id = 0;
 
     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3406,12 +3322,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     const int selected_decoder_id = 0;
 
     if (ctx->state == nullptr) {
-        log("%s: ERROR state was not loaded.\n", __func__);
+        WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3422,7 +3338,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     const auto res = tokenize(ctx->vocab, text);
 
     if (n_max_tokens < (int) res.size()) {
-        log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
         return -1;
     }
 
@@ -3450,7 +3366,7 @@ int whisper_lang_id(const char * lang) {
             }
         }
 
-        log("%s: unknown language '%s'\n", __func__, lang);
+        WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
     return g_lang.at(lang).first;
@@ -3463,7 +3379,7 @@ const char * whisper_lang_str(int id) {
         }
     }
 
-    log("%s: unknown language id %d\n", __func__, id);
+    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
     return nullptr;
 }
 
@@ -3476,25 +3392,25 @@ int whisper_lang_auto_detect_with_state(
     const int seek = offset_ms/10;
 
     if (seek < 0) {
-        log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
         return -1;
     }
 
     if (seek >= state->mel.n_len_org) {
-        log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
+        WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }
 
     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
-        log("%s: failed to encode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
         return -6;
     }
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        log("%s: failed to decode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -7;
     }
 
@@ -3694,8 +3610,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    log("\n");
-    log("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    WHISPER_LOG_INFO("\n");
+    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
     if (ctx->state != nullptr) {
 
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
@@ -3703,14 +3619,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
-        log("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        log("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        log("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        log("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        log("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        log("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+        WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        WHISPER_LOG_INFO("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
-    log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3762,6 +3678,7 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
     s += "SSSE3 = "     + std::to_string(ggml_cpu_has_ssse3())     + " | ";
     s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
+    s += "CUDA = "      + std::to_string(ggml_cpu_has_cublas())    + " | ";
     s += "COREML = "    + std::to_string(whisper_has_coreml())     + " | ";
     s += "OPENVINO = "  + std::to_string(whisper_has_openvino())   + " | ";
 
@@ -4056,7 +3973,7 @@ static void whisper_process_logits(
             const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
             const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
-            //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+            //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
             if (last_was_timestamp) {
                 if (penultimate_was_timestamp) {
@@ -4132,7 +4049,7 @@ static void whisper_process_logits(
 
             const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
-            //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+            //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
                 for (int i = 0; i < vocab.token_beg; ++i) {
@@ -4427,8 +4344,10 @@ static bool whisper_kv_swap_fast(
     for (auto & i : two_copy) {
         // make a copy of KV caches
         WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
+        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
     }
 
     // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@@ -4441,13 +4360,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
@@ -4461,13 +4384,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
@@ -4495,11 +4422,11 @@ int whisper_full_with_state(
         // compute log mel spectrogram
         if (params.speed_up) {
             // TODO: Replace PV with more advanced algorithm
-            log("%s: failed to compute log mel spectrogram\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
         } else {
             if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                log("%s: failed to compute log mel spectrogram\n", __func__);
+                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
                 return -2;
             }
         }
@@ -4511,13 +4438,13 @@ int whisper_full_with_state(
 
         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
-            log("%s: failed to auto-detect language\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
             return -3;
         }
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
-        log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -4575,8 +4502,8 @@ int whisper_full_with_state(
 
         if (decoder.kv_self.ctx == nullptr) {
             decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self)) {
-                log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
+                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
             }
 
@@ -4587,23 +4514,6 @@ int whisper_full_with_state(
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
             decoder.logprobs.resize(ctx->vocab.n_vocab);
-
-            // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
-#ifdef GGML_USE_METAL
-            if (state->ctx_metal) {
-#define WHISPER_METAL_CHECK_BUF(result)              \
-                if (!(result)) {                                 \
-                    log("%s: failed to add metal buffer\n", __func__); \
-                    return 0;                              \
-                }
-
-                const std::string kv_name = "kv_self_" + std::to_string(j);
-                auto & kv_self = decoder.kv_self;
-
-                WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-            }
-#endif
         }
     }
 
@@ -4637,7 +4547,7 @@ int whisper_full_with_state(
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
@@ -4662,7 +4572,7 @@ int whisper_full_with_state(
         // distilled models require the "no_timestamps" token
         // TODO: add input parameter (#1229)
         if (is_distil) {
-            log("%s: using distilled model - forcing no_timestamps\n", __func__);
+            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
             prompt_init.push_back(whisper_token_not(ctx));
         }
     }
@@ -4699,14 +4609,14 @@ int whisper_full_with_state(
 
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to encode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
             return -6;
         }
 
@@ -4789,7 +4699,7 @@ int whisper_full_with_state(
                 WHISPER_PRINT_DEBUG("\n\n");
 
                 if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                    log("%s: failed to decode\n", __func__);
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
 
@@ -4803,8 +4713,11 @@ int whisper_full_with_state(
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
 
-                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        // TODO: fix CUDA
+                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
+                        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
 
                         decoder.kv_self.n += prompt.size();
 
@@ -5013,7 +4926,7 @@ int whisper_full_with_state(
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
                     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                        log("%s: failed to decode\n", __func__);
+                        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
@@ -5339,12 +5252,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
-    log("\n");
-    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
     return ret;
 }
@@ -5586,12 +5499,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+            ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+                ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
                 const int64_t t1 = ggml_time_us();
 
@@ -5709,7 +5622,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
-        log("%s: no signal data available\n", __func__);
+        WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
         return;
     }
 
@@ -5930,6 +5843,38 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
-void whisper_set_log_callback(whisper_log_callback callback) {
-    whisper_log = callback;
+void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+    g_state.log_callback_user_data = user_data;
+}
+
+static void whisper_log_internal_v(ggml_log_level level, const char * format, va_list args) {
+    va_list args_copy;
+    va_copy(args_copy, args);
+    char buffer[128];
+    int len = vsnprintf(buffer, 128, format, args);
+    if (len < 128) {
+        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+    } else {
+        char* buffer2 = new char[len+1];
+        vsnprintf(buffer2, len+1, format, args_copy);
+        buffer2[len] = 0;
+        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        delete[] buffer2;
+    }
+    va_end(args_copy);
+}
+
+static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    whisper_log_internal_v(level, format, args);
+    va_end(args);
+}
+
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
 }
index ed1612b4bc8082ea1a47c509c85515113e0473e4..0ea5237e5f2f7c0d51a2432439f8d425e23b78e6 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -1,6 +1,8 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
@@ -110,15 +112,15 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     WHISPER_DEPRECATED(
         WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
@@ -570,8 +572,7 @@ extern "C" {
 
     // Control logging output; default behavior is to print to stderr
 
-    typedef void (*whisper_log_callback)(const char * line);
-    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
 #ifdef __cplusplus
 }