From: Georgi Gerganov
Date: Sat, 22 Apr 2023 09:36:42 +0000 (+0300)
Subject: ggml : sync llama.cpp (Q4_3 + CUDA)
X-Git-Tag: upstream/0.0.1642~1532
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9d35c55567de97077f37f9d479095241d2790f64;p=pkg%2Fggml%2Fsources%2Fggml

ggml : sync llama.cpp (Q4_3 + CUDA)
---

diff --git a/scripts/sync-llama.sh b/scripts/sync-llama.sh
new file mode 100755
index 00000000..e5d8c4d3
--- /dev/null
+++ b/scripts/sync-llama.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+cp -rpv ../llama.cpp/ggml.c src/ggml.c
+cp -rpv ../llama.cpp/ggml-cuda.cu src/ggml-cuda.cu
+cp -rpv ../llama.cpp/ggml-cuda.h src/ggml-cuda.h
+cp -rpv ../llama.cpp/ggml.h include/ggml/ggml.h
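[Note, not part of the patch: the new CUDA sources below work with four 4-bit block formats. As orientation, here is a minimal host-side sketch that checks the block sizes asserted in src/ggml-cuda.cu and prints the effective bits per weight; the struct definitions mirror the diff, fp16 fields are modelled as raw uint16_t bits (which the static_asserts rely on anyway), and only the test harness itself is an editorial addition.]

#include <stdint.h>
#include <stdio.h>

typedef uint16_t ggml_fp16_t; // fp16 stored as raw bits, as in the diff

#define QK4_0 32
typedef struct { float d;                uint8_t qs[QK4_0 / 2]; } block_q4_0;
#define QK4_1 32
typedef struct { float d; float m;       uint8_t qs[QK4_1 / 2]; } block_q4_1;
#define QK4_2 16
typedef struct { ggml_fp16_t d;                 uint8_t qs[QK4_2 / 2]; } block_q4_2;
#define QK4_3 16
typedef struct { ggml_fp16_t d; ggml_fp16_t m; uint8_t qs[QK4_3 / 2]; } block_q4_3;

int main(void) {
    // bytes per block and effective bits per weight for each format
    printf("q4_0: %zu bytes / %d weights = %.2f bpw\n", sizeof(block_q4_0), QK4_0, 8.0*sizeof(block_q4_0)/QK4_0);
    printf("q4_1: %zu bytes / %d weights = %.2f bpw\n", sizeof(block_q4_1), QK4_1, 8.0*sizeof(block_q4_1)/QK4_1);
    printf("q4_2: %zu bytes / %d weights = %.2f bpw\n", sizeof(block_q4_2), QK4_2, 8.0*sizeof(block_q4_2)/QK4_2);
    printf("q4_3: %zu bytes / %d weights = %.2f bpw\n", sizeof(block_q4_3), QK4_3, 8.0*sizeof(block_q4_3)/QK4_3);
    return 0;
}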
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
new file mode 100644
index 00000000..fa511c1d
--- /dev/null
+++ b/src/ggml-cuda.cu
@@ -0,0 +1,228 @@
+#include <stdint.h>
+#include <stdio.h>
+#include <atomic>
+#include <cuda_fp16.h>
+#include "ggml-cuda.h"
+
+typedef uint16_t ggml_fp16_t;
+static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define QK4_0 32
+typedef struct {
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    __half  d;              // delta
+    uint8_t qs[QK4_2 / 2];  // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+    __half  d;              // delta
+    __half  m;              // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_0 + l + 0] = v0;
+        y[i*QK4_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_1 + l + 0] = v0;
+        y[i*QK4_1 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+    const block_q4_2 * x = (const block_q4_2 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_2; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_2 + l + 0] = v0;
+        y[i*QK4_2 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
+    const block_q4_3 * x = (const block_q4_3 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_3; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_3 + l + 0] = v0;
+        y[i*QK4_3 + l + 1] = v1;
+    }
+}
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t g_cublasH = NULL;
+cudaStream_t g_cudaStream = NULL;
+
+void ggml_init_cublas(void) {
+    if (g_cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
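[Note, not part of the patch: a minimal sanity-check sketch for the new CUDA path, compiled as a .cu file and linked against src/ggml-cuda.cu above. It compares dequantize_row_q4_0_cuda() against the same nibble math on the CPU; the reference loop mirrors the dequantize_block_q4_0 kernel, and block_q4_0 is redeclared locally because the header does not expose it.]

#include <stdint.h>
#include <stdio.h>
#include <math.h>
#include "ggml-cuda.h"

// same layout as block_q4_0 in src/ggml-cuda.cu
#define QK4_0 32
typedef struct { float d; uint8_t qs[QK4_0 / 2]; } block_q4_0;

// CPU reference, mirroring the dequantize_block_q4_0 kernel
static void dequantize_row_q4_0_ref(const block_q4_0 * x, float * y, int k) {
    for (int i = 0; i < k / QK4_0; i++) {
        for (int l = 0; l < QK4_0; l += 2) {
            const uint8_t vi = x[i].qs[l/2];
            y[i*QK4_0 + l + 0] = ((int8_t)(vi & 0xf) - 8) * x[i].d;  // low nibble
            y[i*QK4_0 + l + 1] = ((int8_t)(vi >> 4)  - 8) * x[i].d;  // high nibble
        }
    }
}

int main(void) {
    enum { k = 2 * QK4_0 };
    block_q4_0 x[k / QK4_0];
    for (int i = 0; i < k / QK4_0; i++) {       // arbitrary test pattern
        x[i].d = 0.5f + i;
        for (int j = 0; j < QK4_0 / 2; j++) x[i].qs[j] = (uint8_t)(17*j + i);
    }

    float y_ref[k];
    dequantize_row_q4_0_ref(x, y_ref, k);

    ggml_init_cublas();                          // also creates g_cudaStream
    size_t x_sz, y_sz;
    void * d_x = ggml_cuda_pool_malloc(sizeof(x),        &x_sz);
    void * d_y = ggml_cuda_pool_malloc(k*sizeof(float),  &y_sz);
    CUDA_CHECK(cudaMemcpyAsync(d_x, x, sizeof(x), cudaMemcpyHostToDevice, g_cudaStream));
    dequantize_row_q4_0_cuda(d_x, (float *) d_y, k, g_cudaStream);
    float y_gpu[k];
    CUDA_CHECK(cudaMemcpyAsync(y_gpu, d_y, sizeof(y_gpu), cudaMemcpyDeviceToHost, g_cudaStream));
    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));

    for (int i = 0; i < k; i++) {
        if (fabsf(y_ref[i] - y_gpu[i]) > 1e-6f) { printf("mismatch at %d\n", i); return 1; }
    }
    printf("q4_0 CUDA dequantization matches the CPU reference\n");
    ggml_cuda_pool_free(d_x, x_sz);
    ggml_cuda_pool_free(d_y, y_sz);
    return 0;
}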
diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h
new file mode 100644
index 00000000..370bbc75
--- /dev/null
+++ b/src/ggml-cuda.h
@@ -0,0 +1,41 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t   g_cudaStream;
+
+void   ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/src/ggml.c b/src/ggml.c
index 99860215..4277683e 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err)                                                        \
-    do {                                                                       \
-        cudaError_t err_ = (err);                                              \
-        if (err_ != cudaSuccess) {                                             \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                     \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-#define CUBLAS_CHECK(err)                                                      \
-    do {                                                                       \
-        cublasStatus_t err_ = (err);                                           \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
 #endif
 
 #undef MIN
@@ -487,6 +450,32 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
     return bytes;
 }
 
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -507,6 +496,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -657,9 +664,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s0;         // d * sum(qs[i]) low
+    float   s1;         // d * sum(qs[i]) high
    int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 
 // reference implementation for deterministic creation of model files
@@ -1233,9 +1242,9 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
 
     block_q4_2 * restrict y = vy;
 
-    //quantize_row_q4_2_reference(x, y, k);
+    quantize_row_q4_2_reference(x, y, k);
     // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
-    quantize_row_q4_2_rmse(x, y, k);
+    //quantize_row_q4_2_rmse(x, y, k);
 }
 
 static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@@ -1299,10 +1308,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK8_0; ++l) {
-            const float v = x[i*QK8_0 + l]*id;
-            y[i].qs[l] = roundf(v);
+        int sum0 = 0;
+        int sum1 = 0;
+
+        for (int l = 0; l < QK8_0/2; ++l) {
+            const float v0 = x[i*QK8_0           + l]*id;
+            const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;
+
+            y[i].qs[          l] = roundf(v0);
+            y[i].qs[QK8_0/2 + l] = roundf(v1);
+
+            sum0 += y[i].qs[          l];
+            sum1 += y[i].qs[QK8_0/2 + l];
         }
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 }
@@ -1332,7 +1353,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         y[i].d = d;
 
-        for (int l = 0; l < 8; l++) {
+        int32x4_t accv0 = vdupq_n_s32(0);
+        int32x4_t accv1 = vdupq_n_s32(0);
+
+        // low half
+        for (int l = 0; l < 4; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
 
@@ -1340,7 +1365,28 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv0 = vaddq_s32(accv0, vi);
         }
+
+        // high half
+        for (int l = 4; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv1 = vaddq_s32(accv1, vi);
+        }
+
+        const int32_t sum0 = vaddvq_s32(accv0);
+        const int32_t sum1 = vaddvq_s32(accv1);
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1434,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
@@ -1413,6 +1464,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m128i ni6 = _mm256_castsi256_si128( i3 );
         __m128i ni7 = _mm256_extractf128_si256( i3, 1);
 
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s0 = d * hsum_i32_4(s0);
+        y[i].s1 = d * hsum_i32_4(s1);
+
         // Convert int32 to int16
         ni0 = _mm_packs_epi32( ni0, ni1 );
         ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -2366,20 +2423,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2390,12 +2448,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2410,21 +2462,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2488,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2506,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps( xy_q );
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         /* Multiply q with scale and accumulate */
         acc = _mm256_fmadd_ps( d, q, acc );
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2551,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
@@ -2548,9 +2576,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         }
         sumf += d0*d1*sumi;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2562,19 +2589,21 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     const block_q4_1 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
     // TODO: add AVX / WASM SIMD / etc
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2586,33 +2615,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
-        const int16x8_t s0i = vaddq_s16(
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        const int16x8_t s1i = vaddq_s16(
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
@@ -2637,65 +2655,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+
+        summs += x[i].m * (y[i].s0 + y[i].s1);
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8( bx, bx );
-
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8( by, bx );
-
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
-        const __m256i ones = _mm256_set1_epi16( 1 );
-        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
-
-        // Convert to vector of 8 int32_t to 8 floats
-        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
+    *s = hsum_float_8(acc) + summs;
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float m0 = x[i].m;
@@ -2717,9 +2710,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
             sumf += f0*f2 + f1*f3;
         }
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2732,8 +2724,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     const block_q4_2 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2811,7 +2801,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2833,32 +2823,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        // Get absolute values of x vectors
-        const __m256i ax = _mm256_sign_epi8(bx, bx);
-        // Sign the values of the y vectors
-        const __m256i sy = _mm256_sign_epi8(by, bx);
-        // Perform multiplication and create 16-bit values
-        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
-        const __m256i ones = _mm256_set1_epi16(1);
-        __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
-        /* Convert to vectore of 8 int32_t to 8 floats */
-        __m256 q = _mm256_cvtepi32_ps(xy_q);
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         /* Multiply q with scale and accumulate */
         acc = _mm256_fmadd_ps(d, q, acc);
     }
 
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps(acc, 1);
-    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
-    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
-    res = _mm_add_ss(res, _mm_movehdup_ps(res));
-
-    sumf = _mm_cvtss_f32(res);
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const uint8_t * restrict x0 = x[2*i + 0].qs;
         const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -2893,9 +2867,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         sumf += (d0 * y[i].d) * sumi_0;
         sumf += (d1 * y[i].d) * sumi_1;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
 static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2908,96 +2881,91 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
     const block_q4_3 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
 
-    float sumf = 0.0;
-
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    for (int i = 0; i < nb; i += 2) {
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
-
-        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
-        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
-        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
-        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
-        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
-
-        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
-        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 dx = _mm256_set_m128(d1, d0);
+
+        const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
+        const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
+        const __m256 mx = _mm256_set_m128(m1, m0);
+
+        const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        const __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
+        const __m256 syf = sum_i16_pairs_float(syi);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
+        acc = _mm256_fmadd_ps(sxy, dy, acc);
+    }
+
+    *s = hsum_float_8(acc);
 #else
     // scalar
+    float sumf = 0.0;
     for (int i = 0; i < nb; i++) {
         const uint8_t * restrict x0 = x[2*i + 0].qs;
         const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -3008,9 +2976,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
         const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
 
-        int sy_0 = 0;
-        int sy_1 = 0;
-
         int sxy_0 = 0;
         int sxy_1 = 0;
 
@@ -3030,19 +2995,14 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
             const int y0_1 = y0[2*(j + QK8_0/4) + 0];
             const int y1_1 = y0[2*(j + QK8_0/4) + 1];
 
-            sy_0 += y0_0 + y1_0;
-            sy_1 += y0_1 + y1_1;
-
             sxy_0 += x0_0*y0_0 + x1_0*y1_0;
             sxy_1 += x0_1*y0_1 + x1_1*y1_1;
         }
 
-        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
-        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
     }
-#endif
-
     *s = sumf;
+#endif
 }
 
@@ -3720,7 +3680,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
-        init_cublas();
+        ggml_init_cublas();
         #endif
 
         is_first_call = false;
@@ -7566,18 +7526,16 @@ static void ggml_compute_forward_mul_mat_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
     const float alpha = 1.0f;
    const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+    size_t x_size, y_size, d_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7589,19 +7547,19 @@ static void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
-                cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
                         &alpha, d_X, ne00,
                                 d_Y, ne10,
                         &beta,  d_D, ne01));
 
             // copy data to host
-            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
             // zT = y * xT
             cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7613,10 +7571,10 @@ static void ggml_compute_forward_mul_mat_f32(
         }
     }
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
 #endif
     //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
@@ -7766,18 +7724,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #if defined(GGML_USE_CUBLAS)
     ggml_fp16_t * const wdata = params->wdata;
 
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+    size_t x_size, y_size, d_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
     float * const wdata = params->wdata;
 #endif
@@ -7811,12 +7767,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, CUDA_R_16F, ne00,
                                     d_Y, CUDA_R_16F, ne10,
@@ -7825,7 +7781,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             CUBLAS_GEMM_DEFAULT));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7843,10 +7799,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
 #endif
 
     /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
@@ -8014,20 +7970,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
-    float *d_Q = NULL;
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
    const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+    size_t x_size, y_size, d_size, q_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
     void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
     if (type == GGML_TYPE_Q4_0) {
@@ -8057,9 +8010,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
             // copy and dequantize on device
             CUDA_CHECK(
                 cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
 
-            dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+            dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
             CUDA_CHECK(cudaGetLastError());
#else
             {
@@ -8075,18 +8028,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
-                cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
                         &alpha, d_X, ne00,
                                 d_Y, ne10,
                         &beta,  d_D, ne01));
 
             // copy data to host
-            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#else
             // zT = y * xT
             cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -8099,11 +8052,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
-    CUDA_CHECK(cudaFree(d_Q));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+    ggml_cuda_pool_free(d_Q, q_size);
#endif
     //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
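[Note, not part of the patch: one property of the new buffer pool is worth keeping in mind when reading the mul_mat changes above. ggml_cuda_pool_free() does not release memory to the driver; it parks the buffer for reuse, and a later ggml_cuda_pool_malloc() may hand back a larger buffer than requested, reporting the true capacity through actual_size. A small illustrative sketch against the API declared in src/ggml-cuda.h:]

#include <stdio.h>
#include "ggml-cuda.h"

int main(void) {
    ggml_init_cublas();

    size_t actual = 0;
    void * a = ggml_cuda_pool_malloc(1 << 20, &actual);   // 1 MiB, fresh cudaMalloc
    printf("got %p, actual %zu\n", a, actual);

    ggml_cuda_pool_free(a, actual);                       // parked in the pool, not cudaFree'd

    // a smaller request can now be served by the parked 1 MiB buffer,
    // so actual may come back larger than the request
    void * b = ggml_cuda_pool_malloc(1 << 16, &actual);
    printf("got %p, actual %zu (may be %u)\n", b, actual, 1u << 20);

    ggml_cuda_pool_free(b, actual);
    return 0;
}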