]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : remove Q4_3
authorGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 07:03:59 +0000 (10:03 +0300)
committerGeorgi Gerganov <redacted>
Sat, 29 Apr 2023 07:03:59 +0000 (10:03 +0300)
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml.c

index 540901f15a7f1e58e7a830cb5526efb6e1d124b3..38ae9a6eeeb71b29cb40a8e318ef3319334c3b0c 100644 (file)
@@ -221,7 +221,7 @@ extern "C" {
         GGML_TYPE_Q4_0 = 2,
         GGML_TYPE_Q4_1 = 3,
         GGML_TYPE_Q4_2 = 4,
-        GGML_TYPE_Q4_3 = 5,
+        // GGML_TYPE_Q4_3 (5) support has been removed
         GGML_TYPE_Q5_0 = 6,
         GGML_TYPE_Q5_1 = 7,
         GGML_TYPE_Q8_0 = 8,
@@ -843,7 +843,6 @@ extern "C" {
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
-    GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
index b1bd29b100d811dc09db743c4c6537e2faf82009..5a2701cfeef68696730b7c3c067fb41a60ad221d 100644 (file)
@@ -29,14 +29,6 @@ typedef struct {
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
-#define QK4_3 16
-typedef struct {
-    __half  d;              // delta
-    __half  m;              // min
-    uint8_t qs[QK4_3 / 2];  // nibbles / quants
-} block_q4_3;
-static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
-
 #define QK5_0 32
 typedef struct {
     __half d;               // delta
@@ -131,30 +123,6 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
     }
 }
 
-static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
-    const block_q4_3 * x = (const block_q4_3 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-    const float m = x[i].m;
-
-    const uint8_t * pp = x[i].qs;
-
-    for (int l = 0; l < QK4_3; l += 2) {
-        const uint8_t vi = pp[l/2];
-
-        const int8_t vi0 = vi & 0xf;
-        const int8_t vi1 = vi >> 4;
-
-        const float v0 = vi0*d + m;
-        const float v1 = vi1*d + m;
-
-        y[i*QK4_3 + l + 0] = v0;
-        y[i*QK4_3 + l + 1] = v1;
-    }
-}
-
 static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
@@ -244,11 +212,6 @@ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK4_3;
-    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
-}
-
 void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK5_0;
     dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -264,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q4_2:
+            return dequantize_row_q4_2_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 16
 
@@ -323,19 +305,61 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t g_cublasH = NULL;
-cudaStream_t g_cudaStream = NULL;
+cublasHandle_t g_cublasH = nullptr;
+cudaStream_t g_cudaStream = nullptr;
+cudaStream_t g_cudaStream2 = nullptr;
+cudaEvent_t g_cudaEvent = nullptr;
 
-void ggml_init_cublas(void) {
-    if (g_cublasH == NULL) {
+void ggml_init_cublas() {
+    if (g_cublasH == nullptr) {
         // create cublas handle, bind a stream
         CUBLAS_CHECK(cublasCreate(&g_cublasH));
-
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
-
         CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
 
+        // create additional stream and event for synchronization
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
+        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
+
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+    } else {
+        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
+
+void * ggml_cuda_host_malloc(size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
index ed9b44184bf566ecb0036f3212f9681e7a3be099..36782d9e796b7ed873a036cc3cfb542ead3e1116 100644 (file)
@@ -1,5 +1,6 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
@@ -25,20 +26,29 @@ extern "C" {
     } while (0)
 
 extern cublasHandle_t g_cublasH;
-extern cudaStream_t   g_cudaStream;
+extern cudaStream_t g_cudaStream;
+extern cudaStream_t g_cudaStream2;
+extern cudaEvent_t g_cudaEvent;
 
 void   ggml_init_cublas(void);
+void * ggml_cuda_host_malloc(size_t size);
+void   ggml_cuda_host_free(void * ptr);
+
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
 void   ggml_cuda_pool_free(void * ptr, size_t size);
 
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
+
+typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
+
 #ifdef  __cplusplus
 }
 #endif
index 53796bd97d3abd68c36e7a38d23cce4ec5ac0d5d..64ecd0867e4fb0571141140f558c47656388e8a1 100644 (file)
@@ -694,14 +694,6 @@ typedef struct {
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
-#define QK4_3 16
-typedef struct {
-    ggml_fp16_t d;         // delta
-    ggml_fp16_t m;         // min
-    uint8_t qs[QK4_3 / 2]; // nibbles / quants
-} block_q4_3;
-static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
-
 #define QK5_0 32
 typedef struct {
     ggml_fp16_t d;         // delta
@@ -1291,49 +1283,6 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
     quantize_row_q4_2_reference(x, y, k);
 }
 
-static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
-    assert(k % QK4_3 == 0);
-    const int nb = k / QK4_3;
-
-    for (int i = 0; i < nb; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int l = 0; l < QK4_3; l++) {
-            const float v = x[i*QK4_3 + l];
-            if (v < min) min = v;
-            if (v > max) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = GGML_FP32_TO_FP16(d);
-        y[i].m = GGML_FP32_TO_FP16(min);
-
-        for (int l = 0; l < QK4_3; l += 2) {
-            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
-            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
-
-            const uint8_t vi0 = (int) (v0 + 0.5f);
-            const uint8_t vi1 = (int) (v1 + 0.5f);
-
-            assert(vi0 < 16);
-            assert(vi1 < 16);
-
-            y[i].qs[l/2] = vi0 | (vi1 << 4);
-        }
-    }
-}
-
-static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK4_3 == 0);
-
-    block_q4_3 * restrict y = vy;
-
-    quantize_row_q4_3_reference(x, y, k);
-}
-
 static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
     assert(k % QK5_0 == 0);
     const int nb = k / QK5_0;
@@ -1917,36 +1866,6 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
-static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK4_3 == 0);
-    const int nb = k / QK4_3;
-
-    const block_q4_3 * restrict x = vx;
-
-    for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-        const float m = GGML_FP16_TO_FP32(x[i].m);
-
-        const uint8_t * restrict pp = x[i].qs;
-
-        for (int l = 0; l < QK4_3; l += 2) {
-            const uint8_t vi = pp[l/2];
-
-            const int8_t vi0 = vi & 0x0F;
-            const int8_t vi1 = vi >> 4;
-
-            const float v0 = vi0*d + m;
-            const float v1 = vi1*d + m;
-
-            y[i*QK4_3 + l + 0] = v0;
-            y[i*QK4_3 + l + 1] = v1;
-
-            assert(!isnan(y[i*QK4_3 + l + 0]));
-            assert(!isnan(y[i*QK4_3 + l + 1]));
-        }
-    }
-}
-
 static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
     assert(k % QK5_0 == 0);
     const int nb = k / QK5_0;
@@ -2040,7 +1959,6 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -2070,14 +1988,6 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
-    [GGML_TYPE_Q4_3] = {
-        .dequantize_row_q         = dequantize_row_q4_3,
-        .quantize_row_q           = quantize_row_q4_3,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q4_3_q8_1,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-    },
     [GGML_TYPE_Q5_0] = {
         .dequantize_row_q         = dequantize_row_q5_0,
         .quantize_row_q           = quantize_row_q5_0,
@@ -3171,136 +3081,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
 #endif
 }
 
-static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK8_1;
-
-    assert(n % QK8_1 == 0);
-    assert(nb % 2 == 0);
-    assert(QK8_1 == 2*QK4_3);
-
-    const block_q4_3 * restrict x = vx;
-    const block_q8_1 * restrict y = vy;
-
-#if defined(__ARM_NEON)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
-    float summs0 = 0.0f;
-    float summs1 = 0.0f;
-
-    for (int i = 0; i < nb; ++i) {
-        const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
-        const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-
-        const block_q8_1 * restrict y0 = &y[i + 0];
-
-        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
-        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
-
-        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0x0F)));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-
-        // interleave
-        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-
-        // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-#endif
-    }
-
-    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    float summs = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; i++) {
-        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
-        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
-        const __m256 dx = _mm256_set_m128(d1, d0);
-
-        summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
-               + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
-
-        const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
-        const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
-        const __m256i bx = _mm256_set_m128i(bx1, bx0);
-
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
-        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
-        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
-    }
-
-    *s = hsum_float_8(acc) + summs;
-#else
-    // scalar
-    float sumf = 0.0;
-    for (int i = 0; i < nb; i++) {
-        const uint8_t * restrict x0 = x[2*i + 0].qs;
-        const uint8_t * restrict x1 = x[2*i + 1].qs;
-        const  int8_t * restrict y0 = y[i].qs;
-
-        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
-        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
-        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
-        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
-
-        int sxy_0 = 0;
-        int sxy_1 = 0;
-
-        for (int j = 0; j < QK8_1/4; j++) {
-            const uint8_t v0 = x0[j];
-            const uint8_t v1 = x1[j];
-
-            const int x0_0 = v0 & 0x0F;
-            const int x1_0 = v0 >> 4;
-
-            const int x0_1 = v1 & 0x0F;
-            const int x1_1 = v1 >> 4;
-
-            const int y0_0 = y0[2*j + 0];
-            const int y1_0 = y0[2*j + 1];
-
-            const int y0_1 = y0[2*(j + QK8_1/4) + 0];
-            const int y1_1 = y0[2*(j + QK8_1/4) + 1];
-
-            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
-            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
-        }
-
-        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
-    }
-    *s = sumf;
-#endif
-}
-
 static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
@@ -3925,7 +3705,6 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = QK4_0,
     [GGML_TYPE_Q4_1] = QK4_1,
     [GGML_TYPE_Q4_2] = QK4_2,
-    [GGML_TYPE_Q4_3] = QK4_3,
     [GGML_TYPE_Q5_0] = QK5_0,
     [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
@@ -3942,7 +3721,6 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
     [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
-    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
     [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
@@ -3960,7 +3738,6 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
     [GGML_TYPE_Q4_2] = "q4_2",
-    [GGML_TYPE_Q4_3] = "q4_3",
     [GGML_TYPE_Q5_0] = "q5_0",
     [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
@@ -3977,7 +3754,6 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = true,
     [GGML_TYPE_Q4_1] = true,
     [GGML_TYPE_Q4_2] = true,
-    [GGML_TYPE_Q4_3] = true,
     [GGML_TYPE_Q5_0] = true,
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
@@ -7230,7 +7006,6 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
-        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -8155,8 +7930,12 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+    if (
+#if !defined(GGML_USE_CUBLAS)
+        ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+#endif
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
@@ -8254,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32(
 #if defined(GGML_USE_CUBLAS)
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
@@ -8266,15 +8045,16 @@ static void ggml_compute_forward_mul_mat_f32(
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if !defined(GGML_USE_CUBLAS)
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
+#endif
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
@@ -8455,25 +8235,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        ggml_fp16_t * const wdata = params->wdata;
-
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
+                // copy src0 while converting src1
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+
                 // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne11; ++i01) {
@@ -8494,13 +8276,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
 #if defined(GGML_USE_CUBLAS)
-                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
@@ -8719,42 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size, q_size;
-        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+        float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        void  * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
-        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
-        if (type == GGML_TYPE_Q4_0) {
-            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
-        }
-        else if (type == GGML_TYPE_Q4_1) {
-            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
-        }
-        else if (type == GGML_TYPE_Q4_2) {
-            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
-        }
-        else if (type == GGML_TYPE_Q4_3) {
-            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
-        }
-        else if (type == GGML_TYPE_Q5_0) {
-            dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
-        }
-        else if (type == GGML_TYPE_Q5_1) {
-            dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
-        }
-        else if (type == GGML_TYPE_Q8_0) {
-            dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
-        }
-        else {
-            GGML_ASSERT(false);
-        }
-#elif !defined(GGML_USE_CLBLAST)
+        const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
+        GGML_ASSERT(dequantize_row_q_cuda != NULL);
+#else
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
@@ -8767,12 +8523,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy and dequantize on device
-                CUDA_CHECK(
-                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
 
-                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
+                dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
                 CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
 #elif defined(GGML_USE_CLBLAST)
                 const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
@@ -8786,10 +8541,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 const float * x = wdata;
 #endif
 
-
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
+
+                // wait for dequantization
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
 
                 // compute
                 CUBLAS_CHECK(
@@ -8914,7 +8671,6 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
-        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -9146,7 +8902,6 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
-        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -9472,7 +9227,6 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
-        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -11817,7 +11571,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
                                 //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
                                 //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
                                 //printf("cur = %zu\n", cur);
@@ -11829,6 +11583,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                            }
+#endif
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
@@ -13088,29 +12847,6 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK4_2*sizeof(block_q4_2));
 }
 
-size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK4_3 == 0);
-    const int nb = k / QK4_3;
-
-    for (int j = 0; j < n; j += k) {
-        block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
-
-        quantize_row_q4_3_reference(src + j, y, k);
-
-        for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK4_3; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
-                const uint8_t vi1 = y[i].qs[l/2] >> 4;
-
-                hist[vi0]++;
-                hist[vi1]++;
-            }
-        }
-    }
-
-    return (n/QK4_3*sizeof(block_q4_3));
-}
-
 size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK5_0 == 0);
     const int nb = k / QK5_0;
@@ -13213,12 +12949,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
                 result = ggml_quantize_q4_2(src + start, block, n, n, hist);
             } break;
-        case GGML_TYPE_Q4_3:
-            {
-                GGML_ASSERT(start % QK4_3 == 0);
-                block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
-                result = ggml_quantize_q4_3(src + start, block, n, n, hist);
-            } break;
         case GGML_TYPE_Q5_0:
             {
                 GGML_ASSERT(start % QK5_0 == 0);