]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp (ARM 32-bit, abort callback, wav_writer, etc.) (#602)
authorGeorgi Gerganov <redacted>
Fri, 3 Nov 2023 20:26:27 +0000 (22:26 +0200)
committerGitHub <redacted>
Fri, 3 Nov 2023 20:26:27 +0000 (22:26 +0200)
examples/common.h
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
scripts/sync-whisper.sh
src/ggml-impl.h
src/ggml-metal.m
src/ggml-quants.c
src/ggml.c

index 1d4e9c37c9f50a486497bfb1ad24557b138ffc35..9a94bab7a7099c569f6f559ca1e6a15547ce4f3e 100644 (file)
@@ -7,6 +7,8 @@
 #include <vector>
 #include <random>
 #include <thread>
+#include <ctime>
+#include <fstream>
 
 #define COMMON_SAMPLE_RATE 16000
 
@@ -142,6 +144,104 @@ bool read_wav(
         std::vector<std::vector<float>> & pcmf32s,
         bool stereo);
 
+// Write PCM data into WAV audio file
+class wav_writer {
+private:
+    std::ofstream file;
+    uint32_t dataSize = 0;
+    std::string wav_filename;
+
+    bool write_header(const uint32_t sample_rate,
+                      const uint16_t bits_per_sample,
+                      const uint16_t channels) {
+
+        file.write("RIFF", 4);
+        file.write("\0\0\0\0", 4);    // Placeholder for file size
+        file.write("WAVE", 4);
+        file.write("fmt ", 4);
+
+        const uint32_t sub_chunk_size = 16;
+        const uint16_t audio_format = 1;      // PCM format
+        const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;
+        const uint16_t block_align = channels * bits_per_sample / 8;
+
+        file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);
+        file.write(reinterpret_cast<const char *>(&audio_format), 2);
+        file.write(reinterpret_cast<const char *>(&channels), 2);
+        file.write(reinterpret_cast<const char *>(&sample_rate), 4);
+        file.write(reinterpret_cast<const char *>(&byte_rate), 4);
+        file.write(reinterpret_cast<const char *>(&block_align), 2);
+        file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);
+        file.write("data", 4);
+        file.write("\0\0\0\0", 4);    // Placeholder for data size
+
+        return true;
+    }
+
+    // It is assumed that PCM data is normalized to a range from -1 to 1
+    bool write_audio(const float * data, size_t length) {
+        for (size_t i = 0; i < length; ++i) {
+            const auto intSample = static_cast<const int16_t>(data[i] * 32767);
+            file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
+            dataSize += sizeof(int16_t);
+        }
+        if (file.is_open()) {
+            file.seekp(4, std::ios::beg);
+            uint32_t fileSize = 36 + dataSize;
+            file.write(reinterpret_cast<char *>(&fileSize), 4);
+            file.seekp(40, std::ios::beg);
+            file.write(reinterpret_cast<char *>(&dataSize), 4);
+            file.seekp(0, std::ios::end);
+        }
+        return true;
+    }
+
+    bool open_wav(const std::string & filename) {
+        if (filename != wav_filename) {
+            if (file.is_open()) {
+                file.close();
+            }
+        }
+        if (!file.is_open()) {
+            file.open(filename, std::ios::binary);
+            wav_filename = filename;
+            dataSize = 0;
+        }
+        return file.is_open();
+    }
+
+public:
+    bool open(const std::string & filename,
+              const    uint32_t   sample_rate,
+              const    uint16_t   bits_per_sample,
+              const    uint16_t   channels) {
+
+        if (open_wav(filename)) {
+            write_header(sample_rate, bits_per_sample, channels);
+        } else {
+            return false;
+        }
+
+        return true;
+    }
+
+    bool close() {
+        file.close();
+        return true;
+    }
+
+    bool write(const float * data, size_t length) {
+        return write_audio(data, length);
+    }
+
+    ~wav_writer() {
+        if (file.is_open()) {
+            file.close();
+        }
+    }
+};
+
+
 // Apply a high-pass frequency filter to PCM audio
 // Suppresses frequencies below cutoff Hz
 void high_pass_filter(
index 60c1cca756a683de1321e72a38838957e464b629..bed0789f9ad06ad4b4cd04d2b8bdff586abb1c98 100644 (file)
@@ -83,6 +83,7 @@ struct whisper_params {
     bool output_wts      = false;
     bool output_csv      = false;
     bool output_jsn      = false;
+    bool output_jsn_full = false;
     bool output_lrc      = false;
     bool print_special   = false;
     bool print_colors    = false;
@@ -151,6 +152,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-fp"   || arg == "--font-path")       { params.font_path       = argv[++i]; }
         else if (arg == "-ocsv" || arg == "--output-csv")      { params.output_csv      = true; }
         else if (arg == "-oj"   || arg == "--output-json")     { params.output_jsn      = true; }
+        else if (arg == "-ojf"  || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
         else if (arg == "-of"   || arg == "--output-file")     { params.fname_out.emplace_back(argv[++i]); }
         else if (arg == "-ps"   || arg == "--print-special")   { params.print_special   = true; }
         else if (arg == "-pc"   || arg == "--print-colors")    { params.print_colors    = true; }
@@ -206,6 +208,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -fp,       --font-path         [%-7s] path to a monospace font for karaoke video\n",     params.font_path.c_str());
     fprintf(stderr, "  -ocsv,     --output-csv        [%-7s] output result in a CSV file\n",                    params.output_csv ? "true" : "false");
     fprintf(stderr, "  -oj,       --output-json       [%-7s] output result in a JSON file\n",                   params.output_jsn ? "true" : "false");
+    fprintf(stderr, "  -ojf,      --output-json-full  [%-7s] include more information in the JSON file\n",      params.output_jsn_full ? "true" : "false");
     fprintf(stderr, "  -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n",      "");
     fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
     fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n",                                   params.print_colors ? "true" : "false");
@@ -511,7 +514,12 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe
     return true;
 }
 
-bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+bool output_json(
+             struct whisper_context * ctx,
+                         const char * fname,
+               const whisper_params & params,
+    std::vector<std::vector<float>>   pcmf32s,
+                               bool   full) {
     std::ofstream fout(fname);
     int indent = 0;
 
@@ -528,7 +536,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
     auto end_arr = [&](bool end) {
         indent--;
         doindent();
-        fout << (end ? "]\n" : "},\n");
+        fout << (end ? "]\n" : "],\n");
     };
 
     auto start_obj = [&](const char *name) {
@@ -569,12 +577,29 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
         end_value(end);
     };
 
+    auto value_f = [&](const char *name, const float val, bool end) {
+        start_value(name);
+        fout << val;
+        end_value(end);
+    };
+
     auto value_b = [&](const char *name, const bool val, bool end) {
         start_value(name);
         fout << (val ? "true" : "false");
         end_value(end);
     };
 
+    auto times_o = [&](int64_t t0, int64_t t1, bool end) {
+        start_obj("timestamps");
+        value_s("from", to_timestamp(t0, true).c_str(), false);
+        value_s("to", to_timestamp(t1, true).c_str(), true);
+        end_obj(false);
+        start_obj("offsets");
+        value_i("from", t0 * 10, false);
+        value_i("to", t1 * 10, true);
+        end_obj(end);
+    };
+
     if (!fout.is_open()) {
         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
         return false;
@@ -620,15 +645,26 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
                 const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
 
                 start_obj(nullptr);
-                    start_obj("timestamps");
-                        value_s("from", to_timestamp(t0, true).c_str(), false);
-                        value_s("to", to_timestamp(t1, true).c_str(), true);
-                    end_obj(false);
-                    start_obj("offsets");
-                        value_i("from", t0 * 10, false);
-                        value_i("to", t1 * 10, true);
-                    end_obj(false);
-                    value_s("text", text, !params.diarize && !params.tinydiarize);
+                    times_o(t0, t1, false);
+                    value_s("text", text, !params.diarize && !params.tinydiarize && !full);
+
+                    if (full) {
+                        start_arr("tokens");
+                        const int n = whisper_full_n_tokens(ctx, i);
+                        for (int j = 0; j < n; ++j) {
+                            auto token = whisper_full_get_token_data(ctx, i, j);
+                            start_obj(nullptr);
+                                value_s("text", whisper_token_to_str(ctx, token.id), false);
+                                if(token.t0 > -1 && token.t1 > -1) {
+                                    // If we have per-token timestamps, write them out
+                                    times_o(token.t0, token.t1, false);
+                                }
+                                value_i("id", token.id, false);
+                                value_f("p", token.p, true);
+                            end_obj(j == (n - 1));
+                        }
+                        end_arr(!params.diarize && !params.tinydiarize);
+                    }
 
                     if (params.diarize && pcmf32s.size() == 2) {
                         value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
@@ -912,7 +948,7 @@ int main(int argc, char ** argv) {
             wparams.offset_ms        = params.offset_t_ms;
             wparams.duration_ms      = params.duration_ms;
 
-            wparams.token_timestamps = params.output_wts || params.max_len > 0;
+            wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
             wparams.split_on_word    = params.split_on_word;
@@ -944,8 +980,9 @@ int main(int argc, char ** argv) {
                 wparams.progress_callback_user_data = &user_data;
             }
 
-            // example for abort mechanism
-            // in this example, we do not abort the processing, but we could if the flag is set to true
+            // examples for abort mechanism
+            // in examples below, we do not abort the processing, but we could if the flag is set to true
+
             // the callback is called before every encoder run - if it returns false, the processing is aborted
             {
                 static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
@@ -957,6 +994,17 @@ int main(int argc, char ** argv) {
                 wparams.encoder_begin_callback_user_data = &is_aborted;
             }
 
+            // the callback is called before every computation - if it returns true, the computation is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.abort_callback = [](void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return is_aborted;
+                };
+                wparams.abort_callback_user_data = &is_aborted;
+            }
+
             if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 10;
@@ -1000,7 +1048,7 @@ int main(int argc, char ** argv) {
             // output to JSON file
             if (params.output_jsn) {
                 const auto fname_jsn = fname_out + ".json";
-                output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
+                output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
             }
 
             // output to LRC file
index 6d4a2c127186fe1a8fd2068d9410f6a7adadaab2..17ef4d9e8abbeb1a7b4eb8e5dc2ced88463aeaf6 100644 (file)
@@ -120,14 +120,23 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_NODES 4096
 
 //
 // ggml helpers
 //
 
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(
+        std::vector<uint8_t> & buf,
+                 ggml_cgraph * graph,
+                         int   n_threads,
+      whisper_abort_callback   abort_callback,
+                        void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
+    plan.abort_callback = abort_callback;
+    plan.abort_callback_data = abort_callback_data;
+
     if (plan.work_size > 0) {
         buf.resize(plan.work_size);
         plan.work_data = buf.data();
@@ -655,7 +664,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct
     auto & meta  = allocr.meta;
     auto & data  = allocr.data;
 
-    meta.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
+    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
     alloc = ggml_allocr_new_measure(tensor_alignment);
 
@@ -1608,7 +1617,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
     ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
@@ -1922,7 +1931,9 @@ static bool whisper_encode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads,
+ whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     // conv
@@ -1936,7 +1947,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
     }
 
@@ -1955,10 +1966,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -1977,10 +1988,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -2024,7 +2035,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
     ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
@@ -2346,7 +2357,9 @@ static bool whisper_decode_internal(
     const whisper_token * tokens,
               const int   n_tokens,
               const int   n_past,
-              const int   n_threads) {
+              const int   n_threads,
+ whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
@@ -2375,10 +2388,10 @@ static bool whisper_decode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -3290,7 +3303,7 @@ int whisper_set_mel(
 }
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return -1;
     }
@@ -3299,7 +3312,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return -1;
     }
@@ -3310,7 +3323,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3327,7 +3340,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
         log("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -4594,7 +4607,7 @@ int whisper_full_with_state(
         }
 
         // encode audio features starting at offset seek
-        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
             log("%s: failed to encode\n", __func__);
             return -6;
         }
@@ -4677,7 +4690,7 @@ int whisper_full_with_state(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     log("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -4901,7 +4914,7 @@ int whisper_full_with_state(
 
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         log("%s: failed to decode\n", __func__);
                         return -8;
                     }
@@ -5270,6 +5283,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
     return ctx->state->result_all[i_segment].t1;
 }
 
+bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].speaker_turn_next;
+}
+
 bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
     return ctx->state->result_all[i_segment].speaker_turn_next;
 }
@@ -5471,12 +5488,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_helper(work, gf, n_threads);
+            ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_helper(work, gf, n_threads);
+                ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
 
                 const int64_t t1 = ggml_time_us();
 
index 73ab4d799a23ad73bf9a6406fd9e503116bb8917..c3118c9c99b7ef0d6b217f4f6ff8ceae1d0a2b8f 100644 (file)
@@ -334,6 +334,11 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*whisper_abort_callback)(void * user_data);
+
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -428,6 +433,10 @@ extern "C" {
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
 
+        // called each time before ggml computation starts
+        whisper_abort_callback abort_callback;
+        void * abort_callback_user_data;
+
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;
@@ -485,6 +494,7 @@ extern "C" {
 
     // Get whether the next segment is predicted as a speaker turn
     WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
+    WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
 
     // Get the text of the specified segment
     WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
index c17a091da9b49b4ba1d46297d84046dfc9b606e9..6976c77dc22e78665971ae6b780178dc2d83dd95 100755 (executable)
@@ -1,20 +1,31 @@
 #!/bin/bash
 
 cp -rpv ../whisper.cpp/ggml.c                         src/ggml.c
+cp -rpv ../whisper.cpp/ggml-impl.h                    src/ggml-impl.h
 cp -rpv ../whisper.cpp/ggml-alloc.c                   src/ggml-alloc.c
-cp -rpv ../whisper.cpp/ggml-cuda.h                    src/ggml-cuda.h
+cp -rpv ../whisper.cpp/ggml-backend-impl.h            src/ggml-backend-impl.h
+cp -rpv ../whisper.cpp/ggml-backend.c                 src/ggml-backend.c
 cp -rpv ../whisper.cpp/ggml-cuda.cu                   src/ggml-cuda.cu
-cp -rpv ../whisper.cpp/ggml-opencl.h                  src/ggml-opencl.h
-cp -rpv ../whisper.cpp/ggml-opencl.cpp                src/ggml-opencl.cpp
+cp -rpv ../whisper.cpp/ggml-cuda.h                    src/ggml-cuda.h
 cp -rpv ../whisper.cpp/ggml-metal.h                   src/ggml-metal.h
 cp -rpv ../whisper.cpp/ggml-metal.m                   src/ggml-metal.m
 cp -rpv ../whisper.cpp/ggml-metal.metal               src/ggml-metal.metal
+#cp -rpv ../whisper.cpp/ggml-mpi.h                     src/ggml-mpi.h
+#cp -rpv ../whisper.cpp/ggml-mpi.m                     src/ggml-mpi.m
+cp -rpv ../whisper.cpp/ggml-opencl.cpp                src/ggml-opencl.cpp
+cp -rpv ../whisper.cpp/ggml-opencl.h                  src/ggml-opencl.h
+cp -rpv ../whisper.cpp/ggml-quants.c                  src/ggml-quants.c
+cp -rpv ../whisper.cpp/ggml-quants.h                  src/ggml-quants.h
+
 cp -rpv ../whisper.cpp/ggml.h                         include/ggml/ggml.h
 cp -rpv ../whisper.cpp/ggml-alloc.h                   include/ggml/ggml-alloc.h
+cp -rpv ../whisper.cpp/ggml-backend.h                 include/ggml/ggml-backend.h
+
 cp -rpv ../whisper.cpp/examples/common.h              examples/common.h
 cp -rpv ../whisper.cpp/examples/common.cpp            examples/common.cpp
 cp -rpv ../whisper.cpp/examples/common-ggml.h         examples/common-ggml.h
 cp -rpv ../whisper.cpp/examples/common-ggml.cpp       examples/common-ggml.cpp
+
 cp -rpv ../whisper.cpp/whisper.h                      examples/whisper/whisper.h
 cp -rpv ../whisper.cpp/whisper.cpp                    examples/whisper/whisper.cpp
 cp -rpv ../whisper.cpp/examples/main/main.cpp         examples/whisper/main.cpp
index d88f261449f058d3717ab7de619b14cb8834fa51..06c07339e926999ed1688da0cf40d21ab08ae8ae 100644 (file)
@@ -39,12 +39,6 @@ extern "C" {
 #endif
 #endif
 
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
index 9136a7cf6a1fc9b6cbcb2daf3cd2551e82427513..43d0dff09ecb627638318789fc003da21ecee0a0 100644 (file)
@@ -210,7 +210,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         } else {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
-            NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            NSString * sourcePath;
+            NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+            if (ggmlMetalPathResources) {
+                sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            }
             if (sourcePath == nil) {
                 GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                 sourcePath = @"ggml-metal.metal";
index 740be6dc5c79811f29de00bb335244939da54fb3..a48eda7320c46d2b39f0d9ed76e14aeee3a2461d 100644 (file)
 //
 #include <arm_neon.h>
 
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
 #else
 
 #ifdef __wasm_simd128__
@@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
 #undef MIN
 #undef MAX
+
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
@@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if defined(__ARM_NEON)
-
 #if !defined(__aarch64__)
 
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
@@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+
 #endif
 #endif
 
@@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q2bytes;
+    ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
 
     float sum = 0;
@@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         vst1q_u8(aux, scales);
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
-        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
-        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
@@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
-            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
             MULTIPLY_ACCUM_WITH_SCALE(0);
@@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t  vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q2bytes;
+    ggml_int8x16x4_t q2bytes;
 
     uint32_t aux32[2];
     const uint8_t * scales = (const uint8_t *)aux32;
@@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x16_t q2bits = vld1q_u8(q2);
 
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3 = vshlq_n_u8(m0, 3);
     const int8_t m32 = 32;
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
     float sum = 0;
 
@@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].hmask;
         const int8_t  * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
         int32_t isum = 0;
 
@@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
-            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
-            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
             q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh  = vdupq_n_u8(4);
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
     uint16_t aux16[2];
     int8_t * scales = (int8_t *)aux16;
@@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     for (int i = 0; i < nb; ++i) {
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
         const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
         const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
 
         const uint16_t a = *(const uint16_t *)x[i].scales;
         aux16[0] = a & 0x0f0f;
@@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q4bytes;
-    int8x16x2_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
 
     float sumf = 0;
 
@@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
 #ifdef __ARM_FEATURE_DOTPROD
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
             const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
             sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
@@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
 #else
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
             const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                            vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
             sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     float sumf = 0;
 
-    int8x16x2_t q4bytes;
-    int8x16x4_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x4_t q8bytes;
 
     float sum_mins = 0.f;
 
@@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const float d = y[i].d * (float)x[i].d[0];
 
-        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+        const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
 #ifdef __ARM_FEATURE_DOTPROD
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
@@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
 #else
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
+    ggml_int8x16x4_t q5bytes;
 
     float sumf = 0;
 
@@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q5h;
+        ggml_uint8x16x4_t q5h;
 
         int32_t sumi = 0;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
-            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
-    uint8x16x4_t q5h;
+    ggml_int8x16x4_t q5bytes;
+    ggml_uint8x16x4_t q5h;
 
     float sumf = 0;
 
@@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x8_t qhbits = vld1_u8(qh);
 
-        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
         q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         const int8_t * restrict scale = x[i].scales;
 
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
-        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
-            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
-            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             scale += 2;
 #endif
 
-            q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             shifted = vshrq_n_u8(qhbits.val[0], 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         int32_t isum = 0;
 
-        uint8x16_t   qhbits = vld1q_u8(qh);
-        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
-        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        uint8x16_t qhbits = vld1q_u8(qh);
+        ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+        ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
         uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
index af9f96c3332510abf6378e4f95632b7b02d829f7..018f0ce0bba8ece9504c926820c299d0b3e00055 100644 (file)
@@ -143,6 +143,12 @@ void ggml_print_backtrace(void) {
 }
 #endif
 
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -604,6 +610,18 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 // simd mappings
 //
 
+#if defined(__ARM_NEON)
+#if !defined(__aarch64__)
+
+// 64-bit compatibility
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+#endif
+#endif
+
 // we define a common set of C macros which map to specific intrinsics based on the current architecture
 // we then implement the fundamental computation operations below using only these macros
 // adding support for new architectures requires to define the corresponding SIMD macros