From: Georgi Gerganov Date: Sat, 29 Apr 2023 07:03:59 +0000 (+0300) Subject: ggml : remove Q4_3 X-Git-Tag: upstream/0.0.1642~1510 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=bddea3c5ddad5e1b0ecd878f33d3c0bcc3039c74;p=pkg%2Fggml%2Fsources%2Fggml ggml : remove Q4_3 --- diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 540901f1..38ae9a6e 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -221,7 +221,7 @@ extern "C" { GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, - GGML_TYPE_Q4_3 = 5, + // GGML_TYPE_Q4_3 (5) support has been removed GGML_TYPE_Q5_0 = 6, GGML_TYPE_Q5_1 = 7, GGML_TYPE_Q8_0 = 8, @@ -843,7 +843,6 @@ extern "C" { GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index b1bd29b1..5a2701cf 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -29,14 +29,6 @@ typedef struct { } block_q4_2; static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); -#define QK4_3 16 -typedef struct { - __half d; // delta - __half m; // min - uint8_t qs[QK4_3 / 2]; // nibbles / quants -} block_q4_3; -static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); - #define QK5_0 32 typedef struct { __half d; // delta @@ -131,30 +123,6 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) { } } -static __global__ void dequantize_block_q4_3(const void * vx, float * y) { - const block_q4_3 * x = (const block_q4_3 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - const float m = x[i].m; - - const uint8_t * pp = x[i].qs; - - for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; - - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; - - y[i*QK4_3 + l + 0] = v0; - y[i*QK4_3 + l + 1] = v1; - } -} - static __global__ void dequantize_block_q5_0(const void * vx, float * y) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -244,11 +212,6 @@ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_2<<>>(vx, y); } -void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_3; - dequantize_block_q4_3<<>>(vx, y); -} - void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; dequantize_block_q5_0<<>>(vx, y); @@ -264,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q8_0<<>>(vx, y); } +dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q4_2: + return dequantize_row_q4_2_cuda; + case GGML_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + default: + return nullptr; + } +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 16 @@ -323,19 +305,61 @@ void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } -cublasHandle_t g_cublasH = NULL; -cudaStream_t g_cudaStream = NULL; +cublasHandle_t g_cublasH = nullptr; +cudaStream_t g_cudaStream = nullptr; +cudaStream_t g_cudaStream2 = nullptr; +cudaEvent_t g_cudaEvent = nullptr; -void ggml_init_cublas(void) { - if (g_cublasH == NULL) { +void ggml_init_cublas() { + if (g_cublasH == nullptr) { // create cublas handle, bind a stream CUBLAS_CHECK(cublasCreate(&g_cublasH)); - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); - CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); + // create additional stream and event for synchronization + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking)); + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming)); + // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); } } + +cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) { + const uint64_t ne0 = src->ne[0]; + const uint64_t ne1 = src->ne[1]; + const uint64_t nb0 = src->nb[0]; + const uint64_t nb1 = src->nb[1]; + const uint64_t nb2 = src->nb[2]; + const uint64_t nb3 = src->nb[3]; + const enum ggml_type type = src->type; + const size_t ts = ggml_type_size(type); + const size_t bs = ggml_blck_size(type); + + const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3); + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) ((char *) dst + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream); + if (r != cudaSuccess) return r; + } + return cudaSuccess; + } +} + +void * ggml_cuda_host_malloc(size_t size) { + void * ptr; + CUDA_CHECK(cudaMallocHost((void **) &ptr, size)); + return ptr; +} + +void ggml_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h index ed9b4418..36782d9e 100644 --- a/src/ggml-cuda.h +++ b/src/ggml-cuda.h @@ -1,5 +1,6 @@ #include #include +#include "ggml.h" #ifdef __cplusplus extern "C" { @@ -25,20 +26,29 @@ extern "C" { } while (0) extern cublasHandle_t g_cublasH; -extern cudaStream_t g_cudaStream; +extern cudaStream_t g_cudaStream; +extern cudaStream_t g_cudaStream2; +extern cudaEvent_t g_cudaEvent; void ggml_init_cublas(void); +void * ggml_cuda_host_malloc(size_t size); +void ggml_cuda_host_free(void * ptr); + void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); void ggml_cuda_pool_free(void * ptr, size_t size); void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); +cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream); + +typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); +dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type); + #ifdef __cplusplus } #endif diff --git a/src/ggml.c b/src/ggml.c index 53796bd9..64ecd086 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -694,14 +694,6 @@ typedef struct { } block_q4_2; static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); -#define QK4_3 16 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_3 / 2]; // nibbles / quants -} block_q4_3; -static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); - #define QK5_0 32 typedef struct { ggml_fp16_t d; // delta @@ -1291,49 +1283,6 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int quantize_row_q4_2_reference(x, y, k); } -static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) { - assert(k % QK4_3 == 0); - const int nb = k / QK4_3; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int l = 0; l < QK4_3; l++) { - const float v = x[i*QK4_3 + l]; - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); - - for (int l = 0; l < QK4_3; l += 2) { - const float v0 = (x[i*QK4_3 + l + 0] - min)*id; - const float v1 = (x[i*QK4_3 + l + 1] - min)*id; - - const uint8_t vi0 = (int) (v0 + 0.5f); - const uint8_t vi1 = (int) (v1 + 0.5f); - - assert(vi0 < 16); - assert(vi1 < 16); - - y[i].qs[l/2] = vi0 | (vi1 << 4); - } - } -} - -static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) { - assert(k % QK4_3 == 0); - - block_q4_3 * restrict y = vy; - - quantize_row_q4_3_reference(x, y, k); -} - static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { assert(k % QK5_0 == 0); const int nb = k / QK5_0; @@ -1917,36 +1866,6 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in } } -static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) { - assert(k % QK4_3 == 0); - const int nb = k / QK4_3; - - const block_q4_3 * restrict x = vx; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0x0F; - const int8_t vi1 = vi >> 4; - - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; - - y[i*QK4_3 + l + 0] = v0; - y[i*QK4_3 + l + 1] = v1; - - assert(!isnan(y[i*QK4_3 + l + 0])); - assert(!isnan(y[i*QK4_3 + l + 1])); - } - } -} - static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { assert(k % QK5_0 == 0); const int nb = k / QK5_0; @@ -2040,7 +1959,6 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -2070,14 +1988,6 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_2_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, }, - [GGML_TYPE_Q4_3] = { - .dequantize_row_q = dequantize_row_q4_3, - .quantize_row_q = quantize_row_q4_3, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, - .quantize_row_q_dot = quantize_row_q8_1, - .vec_dot_q = ggml_vec_dot_q4_3_q8_1, - .vec_dot_type = GGML_TYPE_Q8_1, - }, [GGML_TYPE_Q5_0] = { .dequantize_row_q = dequantize_row_q5_0, .quantize_row_q = quantize_row_q5_0, @@ -3171,136 +3081,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_1; - - assert(n % QK8_1 == 0); - assert(nb % 2 == 0); - assert(QK8_1 == 2*QK4_3); - - const block_q4_3 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - for (int i = 0; i < nb; ++i) { - const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; - const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; - - const block_q8_1 * restrict y0 = &y[i + 0]; - - summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; - summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; - - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - - // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); - const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); -#endif - } - - *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 dx = _mm256_set_m128(d1, d0); - - summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0 - + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1; - - const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - const __m256i bx = _mm256_set_m128i(bx1, bx0); - - const __m256 dy = _mm256_broadcast_ss(&y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - *s = hsum_float_8(acc) + summs; -#else - // scalar - float sumf = 0.0; - for (int i = 0; i < nb; i++) { - const uint8_t * restrict x0 = x[2*i + 0].qs; - const uint8_t * restrict x1 = x[2*i + 1].qs; - const int8_t * restrict y0 = y[i].qs; - - const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); - const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); - const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); - const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); - - int sxy_0 = 0; - int sxy_1 = 0; - - for (int j = 0; j < QK8_1/4; j++) { - const uint8_t v0 = x0[j]; - const uint8_t v1 = x1[j]; - - const int x0_0 = v0 & 0x0F; - const int x1_0 = v0 >> 4; - - const int x0_1 = v1 & 0x0F; - const int x1_1 = v1 >> 4; - - const int y0_0 = y0[2*j + 0]; - const int y1_0 = y0[2*j + 1]; - - const int y0_1 = y0[2*(j + QK8_1/4) + 0]; - const int y1_1 = y0[2*(j + QK8_1/4) + 1]; - - sxy_0 += x0_0*y0_0 + x1_0*y1_0; - sxy_1 += x0_1*y0_1 + x1_1*y1_1; - } - - sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; - } - *s = sumf; -#endif -} - static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_0; @@ -3925,7 +3705,6 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = QK4_0, [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, - [GGML_TYPE_Q4_3] = QK4_3, [GGML_TYPE_Q5_0] = QK5_0, [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, @@ -3942,7 +3721,6 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), - [GGML_TYPE_Q4_3] = sizeof(block_q4_3), [GGML_TYPE_Q5_0] = sizeof(block_q5_0), [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), @@ -3960,7 +3738,6 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = "q4_0", [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", - [GGML_TYPE_Q4_3] = "q4_3", [GGML_TYPE_Q5_0] = "q5_0", [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", @@ -3977,7 +3754,6 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = true, [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, - [GGML_TYPE_Q4_3] = true, [GGML_TYPE_Q5_0] = true, [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, @@ -7230,7 +7006,6 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: - case GGML_TYPE_Q4_3: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -8155,8 +7930,12 @@ static bool ggml_compute_forward_mul_mat_use_blas( const int64_t ne1 = dst->ne[1]; // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { + if ( +#if !defined(GGML_USE_CUBLAS) + ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && +#endif + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ return true; @@ -8254,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32( #if defined(GGML_USE_CUBLAS) const float alpha = 1.0f; const float beta = 0.0f; - const int x_ne = ne01 * ne10; + const int x_ne = ne01 * ne00; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; @@ -8266,15 +8045,16 @@ static void ggml_compute_forward_mul_mat_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { +#if !defined(GGML_USE_CUBLAS) const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - +#endif float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); #if defined(GGML_USE_CUBLAS) // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); // compute CUBLAS_CHECK( @@ -8455,25 +8235,27 @@ static void ggml_compute_forward_mul_mat_f16_f32( } #if defined(GGML_USE_CUBLAS) - ggml_fp16_t * const wdata = params->wdata; - const float alpha = 1.0f; const float beta = 0.0f; - const int x_ne = ne01 * ne10; + const int x_ne = ne01 * ne00; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; size_t x_size, y_size, d_size; - float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); + ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); + ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); + float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); #else float * const wdata = params->wdata; #endif for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { #if defined(GGML_USE_CUBLAS) + // copy src0 while converting src1 + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream)); + // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16 + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02); { size_t id = 0; for (int64_t i01 = 0; i01 < ne11; ++i01) { @@ -8494,13 +8276,10 @@ static void ggml_compute_forward_mul_mat_f16_f32( #endif #if defined(GGML_USE_CUBLAS) - const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03); const ggml_fp16_t * y = (ggml_fp16_t *) wdata; - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); // compute @@ -8719,42 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32( #if defined(GGML_USE_CUBLAS) const float alpha = 1.0f; const float beta = 0.0f; - const int x_ne = ne01 * ne10; + const int x_ne = ne01 * ne00; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; size_t x_size, y_size, d_size, q_size; - float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); - float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size); + float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); + float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); + float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); + void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size); - void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL; - if (type == GGML_TYPE_Q4_0) { - dequantize_row_q_cuda = dequantize_row_q4_0_cuda; - } - else if (type == GGML_TYPE_Q4_1) { - dequantize_row_q_cuda = dequantize_row_q4_1_cuda; - } - else if (type == GGML_TYPE_Q4_2) { - dequantize_row_q_cuda = dequantize_row_q4_2_cuda; - } - else if (type == GGML_TYPE_Q4_3) { - dequantize_row_q_cuda = dequantize_row_q4_3_cuda; - } - else if (type == GGML_TYPE_Q5_0) { - dequantize_row_q_cuda = dequantize_row_q5_0_cuda; - } - else if (type == GGML_TYPE_Q5_1) { - dequantize_row_q_cuda = dequantize_row_q5_1_cuda; - } - else if (type == GGML_TYPE_Q8_0) { - dequantize_row_q_cuda = dequantize_row_q8_0_cuda; - } - else { - GGML_ASSERT(false); - } -#elif !defined(GGML_USE_CLBLAST) + const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type); + GGML_ASSERT(dequantize_row_q_cuda != NULL); +#else float * const wdata = params->wdata; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; #endif @@ -8767,12 +8523,11 @@ static void ggml_compute_forward_mul_mat_q_f32( #if defined(GGML_USE_CUBLAS) // copy and dequantize on device - CUDA_CHECK( - cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, - GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2)); - dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); + dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2); CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2)); #elif defined(GGML_USE_CLBLAST) const void* x = (char *) src0->data + i03*nb03 + i02*nb02; #else @@ -8786,10 +8541,12 @@ static void ggml_compute_forward_mul_mat_q_f32( const float * x = wdata; #endif - #if defined(GGML_USE_CUBLAS) // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); + + // wait for dequantization + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0)); // compute CUBLAS_CHECK( @@ -8914,7 +8671,6 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: - case GGML_TYPE_Q4_3: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -9146,7 +8902,6 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: - case GGML_TYPE_Q4_3: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -9472,7 +9227,6 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: - case GGML_TYPE_Q4_3: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -11817,7 +11571,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0)); //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); //printf("cur = %zu\n", cur); @@ -11829,6 +11583,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + } +#endif } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { @@ -13088,29 +12847,6 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_2*sizeof(block_q4_2)); } -size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK4_3 == 0); - const int nb = k / QK4_3; - - for (int j = 0; j < n; j += k) { - block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3; - - quantize_row_q4_3_reference(src + j, y, k); - - for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0x0F; - const uint8_t vi1 = y[i].qs[l/2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK4_3*sizeof(block_q4_3)); -} - size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK5_0 == 0); const int nb = k / QK5_0; @@ -13213,12 +12949,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_2 * block = (block_q4_2*)dst + start / QK4_2; result = ggml_quantize_q4_2(src + start, block, n, n, hist); } break; - case GGML_TYPE_Q4_3: - { - GGML_ASSERT(start % QK4_3 == 0); - block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; - result = ggml_quantize_q4_3(src + start, block, n, n, hist); - } break; case GGML_TYPE_Q5_0: { GGML_ASSERT(start % QK5_0 == 0);