]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
coreml : use the correct `n_mel` value (#1458)
authorXiao-Yong Jin <redacted>
Wed, 8 Nov 2023 20:01:41 +0000 (14:01 -0600)
committerGitHub <redacted>
Wed, 8 Nov 2023 20:01:41 +0000 (20:01 +0000)
coreml/whisper-encoder-impl.h
coreml/whisper-encoder.h
coreml/whisper-encoder.mm
models/convert-whisper-to-coreml.py
whisper.cpp

index ecb61555b94037087f3d4c6d7eb865eb27c1db39..7b83cd906c5c5401dbb27437751cb6f75521b932 100644 (file)
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 /**
     Make a prediction using the convenience interface
-    @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
+    @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_encoder_implOutput
 */
index 84bbe4165052eefea667a18f071bcdba958879b2..508df7c1e94b8277a94e7492b877e1a34dcf9a4b 100644 (file)
@@ -3,6 +3,8 @@
 // Code is derived from the work of Github user @wangchou
 // ref: https://github.com/wangchou/callCoreMLFromCpp
 
+#include <stdint.h>
+
 #if __cplusplus
 extern "C" {
 #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
 
 void whisper_coreml_encode(
         const whisper_coreml_context * ctx,
+                             int64_t   n_ctx,
+                             int64_t   n_mel,
                                float * mel,
                                float * out);
 
index 499edaed434069b5ce4523265864a3865f417469..8e93f180c1b6ef4f900c1a61afac3983e0687634 100644 (file)
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
 
 void whisper_coreml_encode(
         const whisper_coreml_context * ctx,
+                             int64_t   n_ctx,
+                             int64_t   n_mel,
                                float * mel,
                                float * out) {
     MLMultiArray * inMultiArray = [
         [MLMultiArray alloc] initWithDataPointer: mel
-                                           shape: @[@1, @80, @3000]
+                                           shape: @[@1, @(n_mel), @(n_ctx)]
                                         dataType: MLMultiArrayDataTypeFloat32
-                                         strides: @[@(240000), @(3000), @1]
+                                         strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
                                      deallocator: nil
                                            error: nil
     ];
index adbbd1099cb44acaaa739c814e34987139f5d1c7..7e09f5ba6a900eb0079bcf3f87ee7c95f61b4e15 100644 (file)
@@ -252,7 +252,7 @@ class WhisperANE(Whisper):
 def convert_encoder(hparams, model, quantize=False):
     model.eval()
 
-    input_shape = (1, 80, 3000)
+    input_shape = (1, hparams.n_mels, 3000)
     input_data = torch.randn(input_shape)
     traced_model = torch.jit.trace(model, input_data)
 
@@ -302,7 +302,7 @@ if __name__ == "__main__":
     parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
     args = parser.parse_args()
 
-    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
+    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
         raise ValueError("Invalid model name")
 
     whisper = load_model(args.model).cpu()
index 9a1eb2152fed45e0a784e595c38cb299b7b5450e..d36349b0ce340153540ec6d518debec7dcee9599 100644 (file)
@@ -1603,7 +1603,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_allocr_alloc(alloc, cur);
 
         if (!ggml_allocr_is_measure(alloc)) {
-            whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+            whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
         }
 #endif
 #ifdef WHISPER_USE_OPENVINO