}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "ggml.h"
+#include "sgemm.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#include <unistd.h>
#endif
+#ifndef GGML_USE_LLAMAFILE
+#ifdef __ARM_FEATURE_MATMUL_INT8
+#define GGML_USE_LLAMAFILE 0
+#else
+#define GGML_USE_LLAMAFILE 1
+#endif
+#endif
+
#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
}
#endif
+#if GGML_USE_LLAMAFILE
+ if (nb10 == ggml_type_size(src1->type)) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+ nb01/ggml_type_size(src0->type),
+ (const char *)src1->data + i12*nb12 + i13*nb13,
+ nb11/ggml_type_size(src1->type),
+ (char *)dst->data + i12*nb2 + i13*nb3,
+ nb1/ggml_type_size(dst->type),
+ ith, nth,
+ params->type,
+ src0->type,
+ src1->type,
+ dst->type))
+ goto UseGgmlGemm1;
+ return;
+ }
+UseGgmlGemm1:;
+#endif
+
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
return;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+#if GGML_USE_LLAMAFILE
+ if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+ nb01/ggml_type_size(src0->type),
+ (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
+ nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
+ row_size/ggml_type_size(vec_dot_type),
+ (char *)dst->data + i12*nb2 + i13*nb3,
+ nb1/ggml_type_size(dst->type),
+ ith, nth,
+ params->type,
+ src0->type,
+ vec_dot_type,
+ dst->type))
+ goto UseGgmlGemm2;
+ return;
+ }
+UseGgmlGemm2:;
+#endif
+
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = ne1*ne12*ne13; // src1 rows