const int n_segments = whisper_full_n_segments(ctx);
+ std::string speaker = "";
+
+ int64_t t0;
+ int64_t t1;
+
// print the last n_new segments
const int s0 = n_segments - n_new;
+
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
- if (params.no_timestamps) {
- if (params.print_colors) {
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special == false) {
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
- if (id >= whisper_token_eot(ctx)) {
- continue;
- }
- }
-
- const char * text = whisper_full_get_token_text(ctx, i, j);
- const float p = whisper_full_get_token_p (ctx, i, j);
+ if (!params.no_timestamps || params.diarize) {
+ t0 = whisper_full_get_segment_t0(ctx, i);
+ t1 = whisper_full_get_segment_t1(ctx, i);
+ }
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+ if (!params.no_timestamps) {
+ printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+ }
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
- }
- } else {
- const char * text = whisper_full_get_segment_text(ctx, i);
- printf("%s", text);
- }
- fflush(stdout);
- } else {
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
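+        // naive stereo diarization: attribute the segment to whichever channel carries
+        // more absolute energy, with a 10% margin before declaring a winner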
+ if (params.diarize && pcmf32s.size() == 2) {
+ const int64_t n_samples = pcmf32s[0].size();
- std::string speaker;
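+            // segment timestamps are in 10 ms units; convert them to sample indices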
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
- if (params.diarize && pcmf32s.size() == 2) {
- const int64_t n_samples = pcmf32s[0].size();
+            double energy0 = 0.0;
+            double energy1 = 0.0;
- const int64_t is0 = timestamp_to_sample(t0, n_samples);
- const int64_t is1 = timestamp_to_sample(t1, n_samples);
+ for (int64_t j = is0; j < is1; j++) {
+ energy0 += fabs(pcmf32s[0][j]);
+ energy1 += fabs(pcmf32s[1][j]);
+ }
- double energy0 = 0.0f;
- double energy1 = 0.0f;
+ if (energy0 > 1.1*energy1) {
+ speaker = "(speaker 0)";
+ } else if (energy1 > 1.1*energy0) {
+ speaker = "(speaker 1)";
+ } else {
+ speaker = "(speaker ?)";
+ }
- for (int64_t j = is0; j < is1; j++) {
- energy0 += fabs(pcmf32s[0][j]);
- energy1 += fabs(pcmf32s[1][j]);
- }
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
- if (energy0 > 1.1*energy1) {
- speaker = "(speaker 0)";
- } else if (energy1 > 1.1*energy0) {
- speaker = "(speaker 1)";
- } else {
- speaker = "(speaker ?)";
+ if (params.print_colors) {
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ if (params.print_special == false) {
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+ if (id >= whisper_token_eot(ctx)) {
+ continue;
+ }
}
- //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
- }
-
- if (params.print_colors) {
- printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
- if (params.print_special == false) {
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
- if (id >= whisper_token_eot(ctx)) {
- continue;
- }
- }
+ const char * text = whisper_full_get_token_text(ctx, i, j);
+ const float p = whisper_full_get_token_p (ctx, i, j);
- const char * text = whisper_full_get_token_text(ctx, i, j);
- const float p = whisper_full_get_token_p (ctx, i, j);
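+                // map the token probability to a color index (cubing p pushes mid-confidence
+                // tokens toward the low-confidence colors); the index is clamped to the valid range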
+                const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+ printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
+ }
+ } else {
+ const char * text = whisper_full_get_segment_text(ctx, i);
- printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
- }
- printf("\n");
- } else {
- const char * text = whisper_full_get_segment_text(ctx, i);
+ printf("%s%s", speaker.c_str(), text);
+ }
- printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
- }
+        // with timestamps or speakers: print each segment on a new line
+ if (!params.no_timestamps || params.diarize) {
+ printf("\n");
}
+
+ fflush(stdout);
}
}
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
- fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
+ fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer;
+ ggml_type wtype; // weight type (FP32 or FP16)
+
whisper_model model;
whisper_vocab vocab;
};
template<typename T>
-static void read_safe(std::ifstream& fin, T& dest)
-{
- fin.read((char*)& dest, sizeof(T));
+static void read_safe(std::ifstream& fin, T& dest) {
+ fin.read((char*)& dest, sizeof(T));
}
// load the model from a ggml file
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
- const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ const ggml_type wtype = wctx.wtype;
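+
+// note: the same wtype is reused below for the KV memory and the attention scratch
+// tensors, so FP32-only models stay in FP32 end-to-end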
size_t ctx_size = 0;
// encoder
{
- // TODO: F16 .. maybe not?
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
// decoder
{
- // TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
// key/value memory for the cross-attention layer
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
const size_t memory_size =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
// ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
#ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL,
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+ ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
// fully connected
}
{
- for (int i = 0; i < (int) probs_id.size(); i++) {
+ for (const auto & prob : probs_id) {
if (lang_probs) {
- lang_probs[probs_id[i].second] = probs_id[i].first;
+ lang_probs[prob.second] = prob.first;
}
- //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+ //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
}
}
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
+ s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
+ s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str();
}
// separate key + value memory for each processor
{
- auto & ctx = model.ctx_mem;
+ auto & mctx = model.ctx_mem;
const auto & hparams = model.hparams;
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+ model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
// key/value memory for the cross-attention layer
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
}
}
for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = ctxs[i].result_all;
- for (int j = 0; j < (int) results_i.size(); ++j) {
+ for (auto & result : results_i) {
// correct the segment timestamp taking into account the offset
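+            // (timestamps are expressed in 10 ms ticks, hence the factor of 100 per second of audio)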
- results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
- results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+ result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+ result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (!ctx->result_all.empty()) {
- results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
+ result.t0 = std::max(result.t0, ctx->result_all.back().t1);
}
- ctx->result_all.push_back(std::move(results_i[j]));
+ ctx->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
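+// heuristic: rough cost of voicing the given text - punctuation and digits are
+// weighted as longer pauses than ordinary characters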
static float voice_length(const std::string & text) {
float res = 0.0f;
- for (size_t i = 0; i < text.size(); ++i) {
- if (text[i] == ' ') {
+ for (char c : text) {
+ if (c == ' ') {
res += 0.01f;
- } else if (text[i] == ',') {
+ } else if (c == ',') {
res += 2.00f;
- } else if (text[i] == '.') {
+ } else if (c == '.') {
res += 3.00f;
- } else if (text[i] == '!') {
+ } else if (c == '!') {
res += 3.00f;
- } else if (text[i] == '?') {
+ } else if (c == '?') {
res += 3.00f;
- } else if (text[i] >= '0' && text[i] <= '9') {
+ } else if (c >= '0' && c <= '9') {
res += 3.00f;
} else {
res += 1.00f;
#define GGML_MEM_ALIGN 16
#endif
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-
#define UNUSED(x) (void)(x)
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
#include <cblas.h>
#endif
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
// floating point type used to accumulate sums
typedef double ggml_float;
//
#include <arm_neon.h>
-float ggml_fp16_to_fp32(ggml_fp16_t x) {
- return x;
-}
-
-ggml_fp16_t ggml_fp32_to_fp16(float x) {
- return x;
-}
+#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
#define GGML_FP16_TO_FP32(x) (x)
#define GGML_FP32_TO_FP16(x) (x)
#endif
#ifdef __F16C__
-float ggml_fp16_to_fp32(ggml_fp16_t h) {
- return _cvtsh_ss(h);
-}
-ggml_fp16_t ggml_fp32_to_fp16(float f) {
- return _cvtss_sh(f, 0);
-}
-#define GGML_FP16_TO_FP32(x) _cvtsh_ss(x)
-#define GGML_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
#else
return fp32.as_bits;
}
-float ggml_fp16_to_fp32(ggml_fp16_t h) {
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
const uint32_t w = (uint32_t) h << 16;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t two_w = w + w;
return fp32_from_bits(result);
}
-ggml_fp16_t ggml_fp32_to_fp16(float f) {
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
const float scale_to_inf = 0x1.0p+112f;
const float scale_to_zero = 0x1.0p-110f;
return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
}
-#define GGML_FP16_TO_FP32(x) ggml_fp16_to_fp32(x)
-#define GGML_FP32_TO_FP16(x) ggml_fp32_to_fp16(x)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#endif // __F16C__
// precomputed exp table for f16 (128 KB)
static ggml_fp16_t table_exp_f16[1 << 16];
+// precomputed f32 table for f16 (256 KB)
+static float table_f32_f16[1 << 16];
+
+// On ARM NEON, ggml_fp16_t is natively half precision, so the conversion x -> x is a no-op and
+// quicker than calling into ggml_lookup_fp16_to_fp32; GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16
+// are therefore defined elsewhere for NEON.
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
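+// note: memcpy is the portable way to grab the raw f16 bit pattern without
+// aliasing issues; compilers typically lower it to a single 16-bit load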
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+ uint16_t s;
+ memcpy(&s, &f, sizeof(uint16_t));
+ return table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
+
+// note: do not use these inside ggml.c
+// these are meant to be used via the ggml.h API
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+ return GGML_FP16_TO_FP32(x);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+ return GGML_FP32_TO_FP16(x);
+}
+
//
// timing
//
res = vaddvq_f32(vaddq_f32(t0, t1)); \
}
- #define GGML_F16_VEC GGML_F16x8
- #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
- #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
- #define GGML_F16_VEC_LOAD GGML_F16x8_LOAD
- #define GGML_F16_VEC_STORE GGML_F16x8_STORE
- #define GGML_F16_VEC_FMA GGML_F16x8_FMA
- #define GGML_F16_VEC_ADD GGML_F16x8_ADD
- #define GGML_F16_VEC_MUL GGML_F16x8_MUL
- #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
+ #define GGML_F16_VEC GGML_F16x8
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
#else
// if FP16 vector arithmetic is not supported, we use FP32 instead
// and take advantage of the vcvt_ functions to convert to/from FP16
#define GGML_F32Cx4_MUL vmulq_f32
#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
- #define GGML_F16_VEC GGML_F32Cx4
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
- #define GGML_F16_VEC_LOAD GGML_F32Cx4_LOAD
- #define GGML_F16_VEC_STORE GGML_F32Cx4_STORE
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+ #define GGML_F16_VEC GGML_F32Cx4
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#endif
#elif defined(__AVX__)
#define GGML_F32Cx8_MUL _mm256_mul_ps
#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
-#define GGML_F16_VEC GGML_F32Cx8
-#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD GGML_F32Cx8_LOAD
-#define GGML_F16_VEC_STORE GGML_F32Cx8_STORE
-#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+#define GGML_F16_VEC GGML_F32Cx8
+#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
#elif defined(__POWER9_VECTOR__)
-// TODO: uncomment this when it works
-//#define GGML_SIMD
+#define GGML_SIMD
// F32 POWER9
#define GGML_F32_STEP 32
-#define GGML_F32_EPR 8
+#define GGML_F32_EPR 4
-// TODO: not tested !!
-#define GGML_F32x4 __vector float
-#define GGML_F32x4_ZERO (__vector float){0.0f, 0.0f, 0.0f, 0.0f}
-#define GGML_F32x4_SET1(x) (__vector float){x, x, x, x}
-#define GGML_F32x4_LOAD vec_vsx_ld
-#define GGML_F32x4_STORE vec_vsx_st
+#define GGML_F32x4 vector float
+#define GGML_F32x4_ZERO 0.0f
+#define GGML_F32x4_SET1 vec_splats
+#define GGML_F32x4_LOAD(p) vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD vec_add
-#define GGML_F32x4_MUL vec_mul
+#define GGML_F32x4_ADD vec_add
+#define GGML_F32x4_MUL vec_mul
#define GGML_F32x4_REDUCE(res, x) \
{ \
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
// F16 POWER9
-// TODO: implement here
-// ...
+#define GGML_F16_STEP GGML_F32_STEP
+#define GGML_F16_EPR GGML_F32_EPR
+#define GGML_F16_VEC GGML_F32x4
+#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA GGML_F32x4_FMA
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (((i) & 0x1) ? \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p)))
+#define GGML_F16_VEC_STORE(p, r, i) \
+ if (i & 0x1) \
+ vec_xst(vec_pack_to_short_fp32(r[i], r[i - 1]), 0, p - GGML_F16_EPR)
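+// The extra index argument selects the high or low half of each 8 x f16 load, so a
+// pair of consecutive i values shares a single 128-bit vec_xl / vec_xst.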
#elif defined(__wasm_simd128__)
wasm_f32x4_extract_lane(x[0], 3); \
}
-#define GGML_F16_VEC GGML_F16x4
-#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
-#define GGML_F16_VEC_SET1 GGML_F16x4_SET1
-#define GGML_F16_VEC_LOAD GGML_F16x4_LOAD
-#define GGML_F16_VEC_STORE GGML_F16x4_STORE
-#define GGML_F16_VEC_FMA GGML_F16x4_FMA
-#define GGML_F16_VEC_ADD GGML_F16x4_ADD
-#define GGML_F16_VEC_MUL GGML_F16x4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
+#define GGML_F16_VEC GGML_F16x4
+#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 __m128
+#define GGML_F32x4_ZERO _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD _mm_loadu_ps
+#define GGML_F32x4_STORE _mm_storeu_ps
+#if defined(__FMA__)
+  // TODO: verify on FMA-capable hardware (b*c + a matches the non-FMA fallback below)
+ #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+ #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD _mm_add_ps
+#define GGML_F32x4_MUL _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+ x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+ x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+ x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \
+ } \
+ const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
+ res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
+}
+// TODO: is this optimal?
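+// (the loop pairwise-sums the GGML_F32_ARR partial accumulators into x[0];
+// the two hadd's then collapse the remaining 4 lanes into a scalar)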
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 4
+
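+// SSE has no native f16 arithmetic, so each group of 4 half floats is widened to
+// f32 one element at a time and processed with the f32 code path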
+static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+ return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+ float arr[4];
+
+ _mm_storeu_ps(arr, y);
+
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4 __m128
+#define GGML_F32Cx4_ZERO _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD _mm_add_ps
+#define GGML_F32Cx4_MUL _mm_mul_ps
+#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx4
+#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#endif
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
}
for (int i = np; i < n; ++i) {
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
}
-#elif defined(__POWER9_VECTOR__)
- // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
- // being able to test it. hoping someone with access to a POWER9 machine can help out here.
- const int n32 = (n & ~31);
-
- vector float sum0 = vec_splats (0.0f);
-
- for (int i = 0; i < n32; i += 32) {
- // Use vec_xl, not vec_ld, because x is sometimes unaligned.
- vector unsigned short x0 = vec_xl(i * 2 + 0, x);
- vector unsigned short x1 = vec_xl(i * 2 + 16, x);
- vector unsigned short x2 = vec_xl(i * 2 + 32, x);
- vector unsigned short x3 = vec_xl(i * 2 + 48, x);
-
- vector unsigned short y0 = vec_xl(i * 2 + 0, y);
- vector unsigned short y1 = vec_xl(i * 2 + 16, y);
- vector unsigned short y2 = vec_xl(i * 2 + 32, y);
- vector unsigned short y3 = vec_xl(i * 2 + 48, y);
-
- vector float fx0l = vec_extract_fp32_from_shortl(x0);
- vector float fx0h = vec_extract_fp32_from_shorth(x0);
- vector float fx1l = vec_extract_fp32_from_shortl(x1);
- vector float fx1h = vec_extract_fp32_from_shorth(x1);
- vector float fx2l = vec_extract_fp32_from_shortl(x2);
- vector float fx2h = vec_extract_fp32_from_shorth(x2);
- vector float fx3l = vec_extract_fp32_from_shortl(x3);
- vector float fx3h = vec_extract_fp32_from_shorth(x3);
-
- vector float fy0l = vec_extract_fp32_from_shortl(y0);
- vector float fy0h = vec_extract_fp32_from_shorth(y0);
- vector float fy1l = vec_extract_fp32_from_shortl(y1);
- vector float fy1h = vec_extract_fp32_from_shorth(y1);
- vector float fy2l = vec_extract_fp32_from_shortl(y2);
- vector float fy2h = vec_extract_fp32_from_shorth(y2);
- vector float fy3l = vec_extract_fp32_from_shortl(y3);
- vector float fy3h = vec_extract_fp32_from_shorth(y3);
-
- sum0 = vec_add(sum0, vec_mul(fx0l, fy0l));
- sum0 = vec_add(sum0, vec_mul(fx0h, fy0h));
- sum0 = vec_add(sum0, vec_mul(fx1l, fy1l));
- sum0 = vec_add(sum0, vec_mul(fx1h, fy1h));
- sum0 = vec_add(sum0, vec_mul(fx2l, fy2l));
- sum0 = vec_add(sum0, vec_mul(fx2h, fy2h));
- sum0 = vec_add(sum0, vec_mul(fx3l, fy3l));
- sum0 = vec_add(sum0, vec_mul(fx3h, fy3h));
- }
-
- sumf = vec_extract(sum0, 0) + vec_extract(sum0, 1)
- + vec_extract(sum0, 2) + vec_extract(sum0, 3);
-
- for (int i = n32; i < n; ++i) {
- sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
- }
#else
for (int i = 0; i < n; ++i) {
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]);
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
GGML_ASSERT(false);
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
}
-#elif defined(__POWER9_VECTOR__)
- // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
- // being able to test it. hoping someone with access to a POWER9 machine can help out here.
- const int n32 = (n & ~31);
- for (int i = 0; i < n32; i += 32) {
- // Use vec_xl, not vec_ld, because x is sometimes unaligned!
- vector unsigned short x0 = vec_xl(i * 2 + 0, x);
- vector unsigned short x1 = vec_xl(i * 2 + 16, x);
- vector unsigned short x2 = vec_xl(i * 2 + 32, x);
- vector unsigned short x3 = vec_xl(i * 2 + 48, x);
-
- vector unsigned short y0 = vec_xl(i * 2 + 0, y);
- vector unsigned short y1 = vec_xl(i * 2 + 16, y);
- vector unsigned short y2 = vec_xl(i * 2 + 32, y);
- vector unsigned short y3 = vec_xl(i * 2 + 48, y);
-
- vector float v4 = vec_splats(v);
-
- vector float fx0l = vec_extract_fp32_from_shortl(x0);
- vector float fx0h = vec_extract_fp32_from_shorth(x0);
- vector float fx1l = vec_extract_fp32_from_shortl(x1);
- vector float fx1h = vec_extract_fp32_from_shorth(x1);
- vector float fx2l = vec_extract_fp32_from_shortl(x2);
- vector float fx2h = vec_extract_fp32_from_shorth(x2);
- vector float fx3l = vec_extract_fp32_from_shortl(x3);
- vector float fx3h = vec_extract_fp32_from_shorth(x3);
-
- vector float fy0l = vec_extract_fp32_from_shortl(y0);
- vector float fy0h = vec_extract_fp32_from_shorth(y0);
- vector float fy1l = vec_extract_fp32_from_shortl(y1);
- vector float fy1h = vec_extract_fp32_from_shorth(y1);
- vector float fy2l = vec_extract_fp32_from_shortl(y2);
- vector float fy2h = vec_extract_fp32_from_shorth(y2);
- vector float fy3l = vec_extract_fp32_from_shortl(y3);
- vector float fy3h = vec_extract_fp32_from_shorth(y3);
-
- fy0l = vec_madd(fx0l, v4, fy0l);
- fy0h = vec_madd(fx0h, v4, fy0h);
- fy1l = vec_madd(fx1l, v4, fy1l);
- fy1h = vec_madd(fx1h, v4, fy1h);
- fy2l = vec_madd(fx2l, v4, fy2l);
- fy2h = vec_madd(fx2h, v4, fy2h);
- fy3l = vec_madd(fx3l, v4, fy3l);
- fy3h = vec_madd(fx3h, v4, fy3h);
-
- y0 = vec_pack_to_short_fp32(fy0h, fy0l);
- y1 = vec_pack_to_short_fp32(fy1h, fy1l);
- y2 = vec_pack_to_short_fp32(fy2h, fy2l);
- y3 = vec_pack_to_short_fp32(fy3h, fy3l);
-
- vec_xst(y0, i * 2 + 0, y);
- vec_xst(y1, i * 2 + 16, y);
- vec_xst(y2, i * 2 + 32, y);
- vec_xst(y3, i * 2 + 48, y);
- }
-
- for (int i = n32; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
- }
#else
for (int i = 0; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
static atomic_int g_state_barrier = 0;
// barrier via spin lock
-inline static void ggml_critical_section_start() {
+inline static void ggml_critical_section_start(void) {
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
-inline static void ggml_critical_section_end() {
+inline static void ggml_critical_section_end(void) {
atomic_fetch_sub(&g_state_barrier, 1);
}
static bool is_first_call = true;
if (is_first_call) {
- // initialize GELU and EXP tables
+ // initialize GELU, EXP and F32 tables
{
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
for (int i = 0; i < (1 << 16); ++i) {
uint16_t ui = i;
memcpy(&ii, &ui, sizeof(ii));
- const float f = GGML_FP16_TO_FP32(ii);
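+            // use the COMPUTE conversion here: the lookup macro would read table_f32_f16,
+            // which is exactly what this loop is filling in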
+ const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f));
}
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
g_state = (struct ggml_state) {
- /*.contexts =*/ { 0 },
+ /*.contexts =*/ { { 0 } },
};
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
#endif
}
+int ggml_cpu_has_sse3(void) {
+#if defined(__SSE3__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_vsx(void) {
+#if defined(__POWER9_VECTOR__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
////////////////////////////////////////////////////////////////////////////////