/**
Make a prediction using the convenience interface
- @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
+ @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
// Code is derived from the work of Github user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp
+#include <stdint.h>
+
#if __cplusplus
extern "C" {
#endif
// Run the Core ML encoder on one block of mel-spectrogram data.
//
// ctx   - Core ML context handle, passed as a const pointer.
//         NOTE(review): presumably created by an init routine elsewhere; confirm.
// n_ctx - length of the innermost (contiguous) dimension of `mel`.
// n_mel - number of rows in `mel`; the definition wraps `mel` as a float32
//         multi-array of shape [1, n_mel, n_ctx] with strides
//         [n_ctx*n_mel, n_ctx, 1].
// mel   - input buffer of floats laid out as described above.
// out   - destination buffer of floats for the encoder result.
//         NOTE(review): required size/layout is not visible here; confirm
//         against the implementation.
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
+ int64_t n_ctx,
+ int64_t n_mel,
float * mel,
float * out);
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
+ int64_t n_ctx,
+ int64_t n_mel,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
- shape: @[@1, @80, @3000]
+ shape: @[@1, @(n_mel), @(n_ctx)]
dataType: MLMultiArrayDataTypeFloat32
- strides: @[@(240000), @(3000), @1]
+ strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
deallocator: nil
error: nil
];
def convert_encoder(hparams, model, quantize=False):
model.eval()
- input_shape = (1, 80, 3000)
+ input_shape = (1, hparams.n_mels, 3000)
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)
# NOTE: `type=bool` turns every non-empty string into True (bool("False") is
# True), so "--optimize-ane False" used to switch the option ON. Parse the text
# explicitly instead: an empty value or a recognised "false" spelling disables
# the option; any other value still enables it, exactly as before.
parser.add_argument(
    "--optimize-ane",
    type=lambda value: value.lower() not in ("", "0", "false", "f", "no", "n", "off"),
    help="optimize for ANE execution (currently broken)",
    default=False,
)
args = parser.parse_args()
- if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
+ if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
}
#endif
#ifdef WHISPER_USE_OPENVINO