#include <vector>
#include <random>
#include <thread>
+#include <ctime>
+#include <fstream>
#define COMMON_SAMPLE_RATE 16000
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
+// Write PCM data into WAV audio file
+// Streams 16-bit PCM to disk; the RIFF/data size fields in the header are
+// re-patched after every write, so the file on disk stays a valid WAV even
+// if the process exits without an explicit close().
+class wav_writer {
+private:
+    std::ofstream file;
+    uint32_t dataSize = 0;    // number of PCM payload bytes written so far
+    std::string wav_filename; // path of the currently open file
+
+    // Emit the canonical 44-byte RIFF/WAVE header for uncompressed PCM.
+    // The RIFF chunk size and data chunk size are written as zero
+    // placeholders; write_audio() patches them as samples arrive.
+    bool write_header(const uint32_t sample_rate,
+                      const uint16_t bits_per_sample,
+                      const uint16_t channels) {
+
+        file.write("RIFF", 4);
+        file.write("\0\0\0\0", 4); // Placeholder for file size
+        file.write("WAVE", 4);
+        file.write("fmt ", 4);
+
+        const uint32_t sub_chunk_size = 16;
+        const uint16_t audio_format = 1; // PCM format
+        const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;
+        const uint16_t block_align = channels * bits_per_sample / 8;
+
+        // NOTE: fields are written in host byte order; WAV expects
+        // little-endian, which matches the platforms this targets
+        file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);
+        file.write(reinterpret_cast<const char *>(&audio_format), 2);
+        file.write(reinterpret_cast<const char *>(&channels), 2);
+        file.write(reinterpret_cast<const char *>(&sample_rate), 4);
+        file.write(reinterpret_cast<const char *>(&byte_rate), 4);
+        file.write(reinterpret_cast<const char *>(&block_align), 2);
+        file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);
+        file.write("data", 4);
+        file.write("\0\0\0\0", 4); // Placeholder for data size
+
+        return true;
+    }
+
+    // It is assumed that PCM data is normalized to a range from -1 to 1
+    bool write_audio(const float * data, size_t length) {
+        for (size_t i = 0; i < length; ++i) {
+            // NOTE(review): values outside [-1, 1] are not clamped before the
+            // int16 conversion — confirm callers normalize their samples
+            const auto intSample = static_cast<const int16_t>(data[i] * 32767);
+            file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
+            dataSize += sizeof(int16_t);
+        }
+        if (file.is_open()) {
+            // patch the RIFF chunk size (offset 4) and the data chunk size
+            // (offset 40), then return to the end so appends continue normally
+            file.seekp(4, std::ios::beg);
+            uint32_t fileSize = 36 + dataSize;
+            file.write(reinterpret_cast<char *>(&fileSize), 4);
+            file.seekp(40, std::ios::beg);
+            file.write(reinterpret_cast<char *>(&dataSize), 4);
+            file.seekp(0, std::ios::end);
+        }
+        return true;
+    }
+
+    // Open the output stream; switching to a different filename closes the
+    // previous file first and resets the payload byte counter.
+    bool open_wav(const std::string & filename) {
+        if (filename != wav_filename) {
+            if (file.is_open()) {
+                file.close();
+            }
+        }
+        if (!file.is_open()) {
+            file.open(filename, std::ios::binary);
+            wav_filename = filename;
+            dataSize = 0;
+        }
+        return file.is_open();
+    }
+
+public:
+    // Open a WAV file and write its header; returns false if the file could
+    // not be opened for writing
+    bool open(const std::string & filename,
+              const uint32_t sample_rate,
+              const uint16_t bits_per_sample,
+              const uint16_t channels) {
+
+        if (open_wav(filename)) {
+            write_header(sample_rate, bits_per_sample, channels);
+        } else {
+            return false;
+        }
+
+        return true;
+    }
+
+    bool close() {
+        file.close();
+        return true;
+    }
+
+    bool write(const float * data, size_t length) {
+        return write_audio(data, length);
+    }
+
+    ~wav_writer() {
+        // header sizes are already up to date (patched on every write),
+        // so closing the stream is all that is needed here
+        if (file.is_open()) {
+            file.close();
+        }
+    }
+};
+
+
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
+ bool output_jsn_full = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
+ else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
+ fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
return true;
}
-bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
+bool output_json(
+ struct whisper_context * ctx,
+ const char * fname,
+ const whisper_params & params,
+ std::vector<std::vector<float>> pcmf32s,
+ bool full) {
std::ofstream fout(fname);
int indent = 0;
auto end_arr = [&](bool end) {
indent--;
doindent();
- fout << (end ? "]\n" : "},\n");
+ fout << (end ? "]\n" : "],\n");
};
auto start_obj = [&](const char *name) {
end_value(end);
};
+ auto value_f = [&](const char *name, const float val, bool end) {
+ start_value(name);
+ fout << val;
+ end_value(end);
+ };
+
auto value_b = [&](const char *name, const bool val, bool end) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
};
+ auto times_o = [&](int64_t t0, int64_t t1, bool end) {
+ start_obj("timestamps");
+ value_s("from", to_timestamp(t0, true).c_str(), false);
+ value_s("to", to_timestamp(t1, true).c_str(), true);
+ end_obj(false);
+ start_obj("offsets");
+ value_i("from", t0 * 10, false);
+ value_i("to", t1 * 10, true);
+ end_obj(end);
+ };
+
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj(nullptr);
- start_obj("timestamps");
- value_s("from", to_timestamp(t0, true).c_str(), false);
- value_s("to", to_timestamp(t1, true).c_str(), true);
- end_obj(false);
- start_obj("offsets");
- value_i("from", t0 * 10, false);
- value_i("to", t1 * 10, true);
- end_obj(false);
- value_s("text", text, !params.diarize && !params.tinydiarize);
+ times_o(t0, t1, false);
+ value_s("text", text, !params.diarize && !params.tinydiarize && !full);
+
+ if (full) {
+ start_arr("tokens");
+ const int n = whisper_full_n_tokens(ctx, i);
+ for (int j = 0; j < n; ++j) {
+ auto token = whisper_full_get_token_data(ctx, i, j);
+ start_obj(nullptr);
+ value_s("text", whisper_token_to_str(ctx, token.id), false);
+ if(token.t0 > -1 && token.t1 > -1) {
+ // If we have per-token timestamps, write them out
+ times_o(token.t0, token.t1, false);
+ }
+ value_i("id", token.id, false);
+ value_f("p", token.p, true);
+ end_obj(j == (n - 1));
+ }
+ end_arr(!params.diarize && !params.tinydiarize);
+ }
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
- wparams.token_timestamps = params.output_wts || params.max_len > 0;
+ wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.progress_callback_user_data = &user_data;
}
- // example for abort mechanism
- // in this example, we do not abort the processing, but we could if the flag is set to true
+    // examples of the abort mechanism
+    // in the examples below, we do not abort the processing, but we could if the flag is set to true
+
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback_user_data = &is_aborted;
}
+ // the callback is called before every computation - if it returns true, the computation is aborted
+ {
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+ wparams.abort_callback = [](void * user_data) {
+ bool is_aborted = *(bool*)user_data;
+ return is_aborted;
+ };
+ wparams.abort_callback_user_data = &is_aborted;
+ }
+
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
- output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
+ output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
}
// output to LRC file
//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_NODES 4096
//
// ggml helpers
//
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(
+ std::vector<uint8_t> & buf,
+ ggml_cgraph * graph,
+ int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+ plan.abort_callback = abort_callback;
+ plan.abort_callback_data = abort_callback_data;
+
if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.data();
auto & meta = allocr.meta;
auto & data = allocr.data;
- meta.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
+ meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
alloc = ggml_allocr_new_measure(tensor_alignment);
struct ggml_context * ctx0 = ggml_init(params);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
ggml_allocr * alloc = wstate.alloc_encode.alloc;
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
- const int n_threads) {
+ const int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
// conv
ggml_allocr_alloc_graph(alloc, gf);
if (!whisper_encode_external(wstate)) {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
}
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
struct ggml_context * ctx0 = ggml_init(params);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
ggml_allocr * alloc = wstate.alloc_decode.alloc;
const whisper_token * tokens,
const int n_tokens,
const int n_past,
- const int n_threads) {
+ const int n_threads,
+ whisper_abort_callback abort_callback,
+ void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
}
#else
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
#endif
}
}
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
- if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return -1;
}
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
- if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return -1;
}
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int selected_decoder_id = 0;
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return 1;
}
return false;
}
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
log("%s: failed to eval\n", __func__);
return 1;
}
}
// encode audio features starting at offset seek
- if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to encode\n", __func__);
return -6;
}
}
WHISPER_PRINT_DEBUG("\n\n");
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to decode\n", __func__);
return -7;
}
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
log("%s: failed to decode\n", __func__);
return -8;
}
return ctx->state->result_all[i_segment].t1;
}
+// Same accessor as whisper_full_get_segment_speaker_turn_next(), but reads
+// from an explicitly provided state instead of the context's internal state
+bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].speaker_turn_next;
+}
+
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
    return ctx->state->result_all[i_segment].speaker_turn_next;
}
double tsum = 0.0;
// heat-up
- ggml_graph_compute_helper(work, gf, n_threads);
+ ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
for (int i = 0; i < n_max; ++i) {
const int64_t t0 = ggml_time_us();
- ggml_graph_compute_helper(work, gf, n_threads);
+ ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
const int64_t t1 = ggml_time_us();
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
+ // Abort callback
+ // If not NULL, called before ggml computation
+ // If it returns true, the computation is aborted
+ typedef bool (*whisper_abort_callback)(void * user_data);
+
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
+ // called each time before ggml computation starts
+ whisper_abort_callback abort_callback;
+ void * abort_callback_user_data;
+
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
+ WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
#!/bin/bash
cp -rpv ../whisper.cpp/ggml.c src/ggml.c
+cp -rpv ../whisper.cpp/ggml-impl.h src/ggml-impl.h
cp -rpv ../whisper.cpp/ggml-alloc.c src/ggml-alloc.c
-cp -rpv ../whisper.cpp/ggml-cuda.h src/ggml-cuda.h
+cp -rpv ../whisper.cpp/ggml-backend-impl.h src/ggml-backend-impl.h
+cp -rpv ../whisper.cpp/ggml-backend.c src/ggml-backend.c
cp -rpv ../whisper.cpp/ggml-cuda.cu src/ggml-cuda.cu
-cp -rpv ../whisper.cpp/ggml-opencl.h src/ggml-opencl.h
-cp -rpv ../whisper.cpp/ggml-opencl.cpp src/ggml-opencl.cpp
+cp -rpv ../whisper.cpp/ggml-cuda.h src/ggml-cuda.h
cp -rpv ../whisper.cpp/ggml-metal.h src/ggml-metal.h
cp -rpv ../whisper.cpp/ggml-metal.m src/ggml-metal.m
cp -rpv ../whisper.cpp/ggml-metal.metal src/ggml-metal.metal
+#cp -rpv ../whisper.cpp/ggml-mpi.h src/ggml-mpi.h
+#cp -rpv ../whisper.cpp/ggml-mpi.m src/ggml-mpi.m
+cp -rpv ../whisper.cpp/ggml-opencl.cpp src/ggml-opencl.cpp
+cp -rpv ../whisper.cpp/ggml-opencl.h src/ggml-opencl.h
+cp -rpv ../whisper.cpp/ggml-quants.c src/ggml-quants.c
+cp -rpv ../whisper.cpp/ggml-quants.h src/ggml-quants.h
+
cp -rpv ../whisper.cpp/ggml.h include/ggml/ggml.h
cp -rpv ../whisper.cpp/ggml-alloc.h include/ggml/ggml-alloc.h
+cp -rpv ../whisper.cpp/ggml-backend.h include/ggml/ggml-backend.h
+
cp -rpv ../whisper.cpp/examples/common.h examples/common.h
cp -rpv ../whisper.cpp/examples/common.cpp examples/common.cpp
cp -rpv ../whisper.cpp/examples/common-ggml.h examples/common-ggml.h
cp -rpv ../whisper.cpp/examples/common-ggml.cpp examples/common-ggml.cpp
+
cp -rpv ../whisper.cpp/whisper.h examples/whisper/whisper.h
cp -rpv ../whisper.cpp/whisper.cpp examples/whisper/whisper.cpp
cp -rpv ../whisper.cpp/examples/main/main.cpp examples/whisper/main.cpp
#endif
#endif
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
} else {
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
- NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ NSString * sourcePath;
+ NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+ if (ggmlMetalPathResources) {
+ sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+ } else {
+ sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ }
if (sourcePath == nil) {
GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
sourcePath = @"ggml-metal.metal";
//
#include <arm_neon.h>
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
- return
- (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
- (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
- (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
- (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
- return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
#else
#ifdef __wasm_simd128__
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
+#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
#undef MIN
#undef MAX
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON)
-
#if !defined(__aarch64__)
+// 64-bit compatibility
+
+// scalar fallbacks for intrinsics that only exist on AArch64, so the NEON
+// code paths below also build on 32-bit Arm:
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+// horizontal add of all eight int16 lanes, widened to int32
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+// pairwise add of two int16x8 vectors via the 64-bit vpadd_s16 halves
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
// horizontal add of the four int32 lanes
inline static int32_t vaddvq_s32(int32x4_t v) {
    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
return res;
}
+// portable wrappers for the multi-register load intrinsics; on the
+// non-AArch64 branch they are emulated with consecutive single loads,
+// on AArch64 (see #else below) they alias the native intrinsics:
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+// load 16 consecutive int16 values into two int16x8 registers
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+// load 32 consecutive uint8 values into two uint8x16 registers
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+// load 64 consecutive uint8 values into four uint8x16 registers
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+// load 32 consecutive int8 values into two int8x16 registers
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+// load 64 consecutive int8 values into four int8x16 registers
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+// AArch64 provides the multi-register loads natively — alias them directly
+#define ggml_int16x8x2_t int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t int8x16x2_t
+#define ggml_int8x16x4_t int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2 vld1q_u8_x2
+#define ggml_vld1q_u8_x4 vld1q_u8_x4
+#define ggml_vld1q_s8_x2 vld1q_s8_x2
+#define ggml_vld1q_s8_x4 vld1q_s8_x4
+
+
#endif
#endif
const int32x4_t vzero = vdupq_n_s32(0);
#endif
- int8x16x2_t q2bytes;
+ ggml_int8x16x2_t q2bytes;
uint8_t aux[16];
float sum = 0;
vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
- const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
- const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+ const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
#endif
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
- q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
for (int j = 0; j < QK_K/128; ++j) {
- const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+ const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
- int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
const int32x4_t vzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q2bytes;
+ ggml_int8x16x4_t q2bytes;
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
const uint8x16_t q2bits = vld1q_u8(q2);
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;
- int8x16x4_t q3bytes;
+ ggml_int8x16x4_t q3bytes;
float sum = 0;
const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
- uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- uint8x16x4_t q3h;
+ ggml_uint8x16x4_t q3h;
int32_t isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
- const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
- const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
- const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+ const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+ const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4);
- int8x16x4_t q3bytes;
+ ggml_int8x16x4_t q3bytes;
uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16;
for (int i = 0; i < nb; ++i) {
- uint8x16x4_t q3h;
+ ggml_uint8x16x4_t q3h;
const uint8x8_t hbits = vld1_u8(x[i].hmask);
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
- const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x2_t q4bytes;
- int8x16x2_t q8bytes;
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x2_t q8bytes;
float sumf = 0;
for (int j = 0; j < QK_K/64; ++j) {
- const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
#ifdef __ARM_FEATURE_DOTPROD
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
float sumf = 0;
- int8x16x2_t q4bytes;
- int8x16x4_t q8bytes;
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x4_t q8bytes;
float sum_mins = 0.f;
const float d = y[i].d * (float)x[i].d[0];
- const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD
- q8bytes = vld1q_s8_x4(q8);
+ q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
#else
- q8bytes = vld1q_s8_x4(q8);
+ q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q5bytes;
+ ggml_int8x16x4_t q5bytes;
float sumf = 0;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
- uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- uint8x16x4_t q5h;
+ ggml_uint8x16x4_t q5h;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
- const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q5bytes;
- uint8x16x4_t q5h;
+ ggml_int8x16x4_t q5bytes;
+ ggml_uint8x16x4_t q5h;
float sumf = 0;
const uint8x8_t qhbits = vld1_u8(qh);
- const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
const uint8x16_t mone = vdupq_n_u8(3);
- int8x16x4_t q6bytes;
- uint8x16x4_t q6h;
+ ggml_int8x16x4_t q6bytes;
+ ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
const int8_t * restrict scale = x[i].scales;
- const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
- const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+ const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
for (int j = 0; j < QK_K/128; ++j) {
- uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
- uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
- int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+ ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
scale += 2;
#endif
- q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
const uint8x16_t mone = vdupq_n_u8(3);
- int8x16x4_t q6bytes;
- uint8x16x4_t q6h;
+ ggml_int8x16x4_t q6bytes;
+ ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
int32_t isum = 0;
- uint8x16_t qhbits = vld1q_u8(qh);
- uint8x16x2_t q6bits = vld1q_u8_x2(q6);
- int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ uint8x16_t qhbits = vld1q_u8(qh);
+ ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
}
#endif
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
/*#define GGML_PERF*/
#define GGML_DEBUG 0
#define GGML_GELU_FP16
// simd mappings
//
+#if defined(__ARM_NEON)
+#if !defined(__aarch64__)
+
+// 64-bit compatibility
+
+// scalar fallback for the AArch64-only horizontal float add intrinsic,
+// so 32-bit Arm builds can use the same SIMD code paths below
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+#endif
+#endif
+
// we define a common set of C macros which map to specific intrinsics based on the current architecture
// we then implement the fundamental computation operations below using only these macros
// adding support for new architectures requires to define the corresponding SIMD macros