]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: lower GPU latency + fix Windows performance (#3110)
authorJohannes Gäßler <redacted>
Mon, 11 Sep 2023 17:55:51 +0000 (19:55 +0200)
committerGitHub <redacted>
Mon, 11 Sep 2023 17:55:51 +0000 (19:55 +0200)
ggml-cuda.cu

index 50344ae87ae127adfb794bbb32bd6ea78a6ac73b..9e9eac487103e0f38bf32548d11ab3d0e1d3f1a4 100644 (file)
@@ -221,10 +221,13 @@ typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__
 typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_cuda_op_t)(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
-    float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main);
+typedef void (*ggml_cuda_op_mul_mat_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream);
+typedef void (*ggml_cuda_op_flatten_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -405,11 +408,29 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+#define MAX_STREAMS 8
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+inline cudaError_t ggml_cuda_set_device(const int device) {
+    int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+
+    if (device == current_device) {
+        return cudaSuccess;
+    }
+
+    return cudaSetDevice(device);
+}
+
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
@@ -422,8 +443,6 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
-
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -5139,25 +5158,27 @@ void ggml_init_cublas() {
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
         fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
-        for (int id = 0; id < g_device_count; ++id) {
+        for (int64_t id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-            fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+            fprintf(stderr, "  Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
 
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
         }
-        for (int id = 0; id < g_device_count; ++id) {
+        for (int64_t id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
         }
 
-        for (int id = 0; id < g_device_count; ++id) {
-            CUDA_CHECK(cudaSetDevice(id));
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            CUDA_CHECK(ggml_cuda_set_device(id));
 
-            // create main stream
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
+            // create cuda streams
+            for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
+            }
 
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -5265,225 +5286,169 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
 }
 
 inline void ggml_cuda_op_add(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i  != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+        add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
     } else {
         GGML_ASSERT(false);
     }
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
 }
 
 inline void ggml_cuda_op_mul(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i  != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+    mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
 
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
 }
 
 inline void ggml_cuda_op_gelu(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    // compute
-    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_silu(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    // compute
-    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_norm(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
-    // compute
-    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_rms_norm(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
+    rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_mul_mat_q(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddq_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
     const int64_t ne00 = src0->ne[0];
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
 
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
-    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
-
-    const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
-    size_t as;
-    void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as);
-    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main);
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_1:
-            ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_0:
-            ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_1:
-            ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q8_0:
-            ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q2_K:
-            ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q3_K:
-            ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_K:
-            ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_K:
-            ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q6_K:
-            ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         default:
             GGML_ASSERT(false);
             break;
     }
 
-    ggml_cuda_pool_free(src1_q8_1, as);
-
     (void) src1;
     (void) dst;
-    (void) src0_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_ddf_i;
 }
 
 static int64_t get_row_rounding(ggml_type type) {
@@ -5517,168 +5482,144 @@ static int64_t get_row_rounding(ggml_type type) {
     }
 }
 
-inline void ggml_cuda_op_mul_mat_vec(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddq_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+inline void ggml_cuda_op_mul_mat_vec_q(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
 
-#ifdef GGML_CUDA_FORCE_DMMV
-    const bool use_mul_mat_vec_q = false;
-    (void) g_compute_capabilities[0];
-#else
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
 
-    bool mul_mat_vec_q_implemented =
-        src0->type == GGML_TYPE_Q4_0 ||
-        src0->type == GGML_TYPE_Q4_1 ||
-        src0->type == GGML_TYPE_Q5_0 ||
-        src0->type == GGML_TYPE_Q5_1 ||
-        src0->type == GGML_TYPE_Q8_0;
-#if QK_K == 256
-    mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
-        src0->type == GGML_TYPE_Q2_K ||
-        src0->type == GGML_TYPE_Q3_K ||
-        src0->type == GGML_TYPE_Q4_K ||
-        src0->type == GGML_TYPE_Q5_K ||
-        src0->type == GGML_TYPE_Q6_K;
-#endif // QK_K == 256
-
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
-#endif
+    (void) src1;
+    (void) dst;
+    (void) src1_ddf_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
+}
 
-    if (use_mul_mat_vec_q) {
-        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
-            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
-        size_t as;
-        void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
-        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
-
-        switch (src0->type) {
-            case GGML_TYPE_Q4_0:
-                mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_1:
-                mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_0:
-                mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_1:
-                mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q8_0:
-                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q2_K:
-                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q3_K:
-                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_K:
-                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_K:
-                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q6_K:
-                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
-        ggml_cuda_pool_free(src1_q8_1, as);
-    } else {
-        // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_CUDA_F16
-        size_t ash;
-        dfloat * src1_dfloat = nullptr; // dfloat == half
-
-        bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
-            src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
-            src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
-
-        if (src1_convert_f16) {
-            src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
-            ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
-                                    ne00, 1, sizeof(float), 0, 0,
-                                    ne00, 1, sizeof(half),  0, 0, cudaStream_main);
-        }
+    size_t ash;
+    dfloat * src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+                                ne00, 1, sizeof(float), 0, 0,
+                                ne00, 1, sizeof(half),  0, 0, stream);
+    }
 #else
-        dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
 #endif // GGML_CUDA_F16
 
-        switch (src0->type) {
-            case GGML_TYPE_Q4_0:
-                dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_1:
-                dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_0:
-                dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_1:
-                dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q8_0:
-                dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q2_K:
-                dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q3_K:
-                dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_K:
-                dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_K:
-                dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q6_K:
-                dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_F16:
-                convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
 
 #ifdef GGML_CUDA_F16
-        if (src1_convert_f16) {
-            ggml_cuda_pool_free(src1_dfloat, ash);
-        }
-#endif // GGML_CUDA_F16
+    if (src1_convert_f16) {
+        ggml_cuda_pool_free(src1_dfloat, ash);
     }
+#endif // GGML_CUDA_F16
 
     (void) src1;
     (void) dst;
-    (void) src0_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_ddq_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
 }
 
 inline void ggml_cuda_op_mul_mat_cublas(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_dd_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(dst_dd_i != nullptr);
 
     const float alpha = 1.0f;
     const float beta = 0.0f;
@@ -5686,43 +5627,48 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t ne00 = src0->ne[0];
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
 
     const int64_t ne0 = dst->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+    size_t src0_as;
+    float * src0_ddf_i = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as);
+    to_fp32_cuda(src0_dd_i, src0_ddf_i, row_diff*ne00, stream);
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
 
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
     CUBLAS_CHECK(
         cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                i01_diff, ne11, ne10,
+                row_diff, src1_ncols, ne10,
                 &alpha, src0_ddf_i, ne00,
-                        src1_ddf_i, ne10,
-                &beta,  dst_ddf_i,  ldc));
+                        src1_ddf_i,  ne10,
+                &beta,  dst_dd_i,   ldc));
+
+    ggml_cuda_pool_free(src0_ddf_i, src0_as);
 
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
+    (void) src0_dd_i;
+    (void) src1_ddq_i;
+    (void) src1_padded_row_size;
 }
 
 inline void ggml_cuda_op_rope(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -5742,33 +5688,30 @@ inline void ggml_cuda_op_rope(
 
     // compute
     if (is_glm) {
-        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
+        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+        rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+        rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
     }
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_alibi(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
@@ -5783,334 +5726,355 @@ inline void ggml_cuda_op_alibi(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    // compute
-    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+    alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
 
     (void) src1;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_diag_mask_inf(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int nrows0 = ggml_nrows(src0);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
 
-    // compute
-    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+    diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_soft_max(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
-    // compute
-    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const float scale = ((float *) src1->data)[0];
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
-
-    // compute
-    scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+    scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(             src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_GPU_SPLIT);
+
+    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
+
+    const bool src0_on_device =             src0->backend == GGML_BACKEND_GPU;
+    const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
+    const bool  dst_on_device =              dst->backend == GGML_BACKEND_GPU;
+
+    const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
+
+    // dd = data device
+    float * src0_ddf = nullptr;
+    float * src1_ddf = nullptr;
+    float *  dst_ddf = nullptr;
+
+    // as = actual size
+    size_t src0_asf = 0;
+    size_t src1_asf = 0;
+    size_t  dst_asf = 0;
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    if (src0_on_device) {
+        src0_ddf = (float *) src0_extra->data_device[g_main_device];
+    } else {
+        src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
+    }
+
+    if (use_src1 && !src1_stays_on_host) {
+        if (src1_on_device) {
+            src1_ddf = (float *) src1_extra->data_device[g_main_device];
+        } else {
+            src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
+        }
+    }
+    if (dst_on_device) {
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    } else {
+        dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+    }
+
+    // do the computation
+    op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    // copy dst to host if necessary
+    if (!dst_on_device) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+    }
+
+    if (src0_asf > 0) {
+        ggml_cuda_pool_free(src0_ddf, src0_asf);
+    }
+    if (src1_asf > 0) {
+        ggml_cuda_pool_free(src1_ddf, src1_asf);
+    }
+    if (dst_asf > 0) {
+        ggml_cuda_pool_free(dst_ddf, dst_asf);
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
-static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
+static void ggml_cuda_op_mul_mat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+    const bool convert_src1_to_q8_1) {
+
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
     const int64_t nrows0 = ggml_nrows(src0);
 
-    const bool use_src1 = src1 != nullptr;
-    const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
-    const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
-    const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
-    const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
-    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nrows1 = ggml_nrows(src1);
 
     GGML_ASSERT(ne03 == ne13);
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
-    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
 
-    // strides for iteration over dims 3 and 2
-    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
-    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
-    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
-    const int64_t src0_stride = ne00 * ne01 * stride_mod;
-    const int64_t src1_stride = ne10 * ne11 * stride_mod;
-    const int64_t dst_stride = ne0 * ne1 * stride_mod;
+    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
-    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-    const int64_t i03_max = flatten_rows ? 1 : ne03;
-    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
-    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
-    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+    const int64_t i02_divisor = ne12 / ne02;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
 
-    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
-    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *) dst->extra;
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    struct ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
-    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
 
-    const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
-    const bool src1_stays_on_host = use_src1 && (
-        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+    const bool src1_is_contiguous = ggml_is_contiguous(src1);
+    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 > 1));
+    GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-
     // dd = data device
-    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
-    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
-    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-
-    // asq = actual size quantized, asf = actual size float
-    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    char  *  src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+    char  * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1
+    float *   dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    // if multiple devices are used they need to wait for the main device
-    // here an event is recorded that signifies that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
-    }
+    // as = actual size
+    size_t  src0_as[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t   dst_as[GGML_CUDA_MAX_DEVICES] = {0};
 
-    for (int id = 0; id < g_device_count; ++id) {
-        if (!split && id != g_main_device) {
-            continue;
-        }
+    int64_t  row_low[GGML_CUDA_MAX_DEVICES];
+    int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
-        const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
-        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        // by default, use all rows
+        row_low[id]  = 0;
+        row_high[id] = ne01;
 
-        int64_t row_low, row_high;
+        // for multi GPU, get the row boundaries from tensor split
+        // and round to mul_mat_q tile sizes
         if (split) {
             const int64_t rounding = get_row_rounding(src0->type);
 
-            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
-            row_low -= row_low % rounding;
+            if (id != 0) {
+                row_low[id]  = ne01*g_tensor_split[id];
+                row_low[id] -= row_low[id] % rounding;
+            }
 
-            if (id == g_device_count - 1) {
-                row_high = nrows0;
-            } else {
-                row_high = nrows0*g_tensor_split[id + 1];
-                row_high -= row_high % rounding;
+            if (id != g_device_count - 1) {
+                row_high[id]  = ne01*g_tensor_split[id + 1];
+                row_high[id] -= row_high[id] % rounding;
             }
-        } else {
-            row_low = 0;
-            row_high = nrows0*i02_divisor;
         }
-        if (row_low == row_high) {
+    }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
             continue;
         }
 
-        int64_t row_diff = row_high - row_low;
+        const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+        const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
-        cudaSetDevice(id);
-        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
-        // wait for main GPU data if necessary
-        if (split && id != g_main_device) {
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
-        }
+        ggml_cuda_set_device(id);
+        const cudaStream_t stream = g_cudaStreams[id][0];
 
         if (src0_on_device && src0_is_contiguous) {
-            if (src0_is_f32) {
-                src0_ddf[id] = (float *) src0_extra->data_device[id];
-            } else {
-                src0_ddq[id] = (char *) src0_extra->data_device[id];
-            }
+            src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
-            if (src0_is_f32) {
-                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
-            } else {
-                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
-            }
+            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
-        if (src0_needs_f32 && !src0_is_f32) {
-            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+        if (src1_on_device && src1_is_contiguous) {
+            src1_ddf[id] = (float *) src1_extra->data_device[id];
+        } else {
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
-        if (use_src1 && !src1_stays_on_host) {
-            if (src1_on_device && src1_is_contiguous) {
-                src1_ddf[id] = (float *) src1_extra->data_device[id];
-            } else {
-                src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+        if (convert_src1_to_q8_1) {
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+
+            if (split && src1_on_device && src1_is_contiguous) {
+                quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
             }
         }
+
         if (dst_on_device) {
-            dst_ddf[id] = (float *) dst_extra->data_device[id];
+            dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
-            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
-            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+            const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
+    }
 
-        for (int64_t i03 = 0; i03 < i03_max; i03++) {
-            const int64_t i13 = i03 % ne13;
-            for (int64_t i02 = 0; i02 < i02_max; i02++) {
-                const int64_t i12 = i02 % ne12;
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signals that the main device has finished calculating the input data
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
+    }
 
-                const int64_t i0 = i03*i02_max + i02;
+    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+        const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
+        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
 
-                // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
-                const int64_t i0_offset_low = row_low/rows_per_iter;
-                const int64_t i0_offset_high = row_high/rows_per_iter;
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+                continue;
+            }
 
-                int64_t i01_low = 0;
-                int64_t i01_high = rows_per_iter;
-                if (split) {
-                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
-                        continue;
-                    }
-                    if (i0 == i0_offset_low) {
-                        i01_low = row_low % rows_per_iter;
-                    }
-                    if (i0 == i0_offset_high) {
-                        i01_high = row_high % rows_per_iter;
-                    }
-                }
+            const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+            const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+            const int64_t row_diff = row_high[id] - row_low[id];
 
-                // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
-                // Removing the first assert or changing the order of the arguments causes the second assert to fail.
-                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
-                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
-                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
-                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
+            ggml_cuda_set_device(id);
+            const cudaStream_t stream = g_cudaStreams[id][is];
 
-                const int64_t i01_diff = i01_high - i01_low;
-                if (i01_diff == 0) {
-                    continue;
-                }
-                const int64_t i11 = i13*ne12 + i12;
+            // wait for main GPU data if necessary
+            if (split && (id != g_main_device || is != 0)) {
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+            }
+
+            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+                const int64_t i03 = i0 / ne12;
+                const int64_t i02 = i0 % ne12;
+
+                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
-                float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-                float * dst_ddf_i  =  dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
-
-                // for split tensors the data pointer needs to be rounded down
-                // to the bin edge for i03, i02 bins beyond the first
-                if (i0 - i0_offset_low > 0) {
-                    GGML_ASSERT(!flatten_rows);
-                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
-                    src0_ddf_i -= (row_low % ne01)*ne00;
-                    dst_ddf_i  -= (row_low % ne0)*ne1;
-                }
+                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+                float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
 
                 // the main device memory buffer can be on VRAM scratch, with space for all partial results
                 // in that case an offset on dst_ddf_i is needed
                 if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
-                    dst_ddf_i += i01_low; // offset is 0 if no tensor split
+                    dst_dd_i += row_low[id]; // offset is 0 if no tensor split
                 }
 
                 // copy src0, src1 to device if necessary
-                if (use_src1 && !src1_stays_on_host) {
-                    if (src1->backend == GGML_BACKEND_CPU) {
-                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
-                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
-                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
-                        if (id != g_main_device) {
-                            GGML_ASSERT(!flatten_rows);
+                if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
+                    if (id != g_main_device) {
+                        if (convert_src1_to_q8_1) {
+                            char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs,
+                                                    cudaMemcpyDeviceToDevice, stream));
+                        } else {
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
-                            src1_ddf_i_source += i11*src1_stride;
-                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
-                                                    cudaMemcpyDeviceToDevice, cudaStream_main));
+                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float),
+                                                    cudaMemcpyDeviceToDevice, stream));
                         }
-                    } else if (src1_on_device && !src1_is_contiguous) {
-                        GGML_ASSERT(!split);
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
-                    } else {
-                        GGML_ASSERT(false);
                     }
+                } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                } else {
+                    GGML_ASSERT(false);
                 }
 
-                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
-                    if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
-                    } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
-                    }
+                if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) {
+                    quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                    CUDA_CHECK(cudaGetLastError());
                 }
 
-                // convert src0 to f32 if it is necessary for the ggml_cuda_op
-                if (src0_needs_f32 && !src0_is_f32) {
-                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
-                    CUDA_CHECK(cudaGetLastError());
+                if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream));
                 }
 
                 // do the computation
-                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                   row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream);
                 CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
@@ -6132,95 +6096,86 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                         // Instead they need to be copied to the correct slice in ne0 = dst row index.
                         // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i01_low*sizeof(float) + i02*nb2 + i03*nb3);
-                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_ddf_i, i01_diff*sizeof(float),
-                                                     i01_diff*sizeof(float), ne1, kind, cudaStream_main));
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0 + row_low[id];
+                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
+                                                    row_diff*sizeof(float), src1_ncols, kind, stream));
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0;
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream));
                     }
                 }
 
-                // signify to main device that other device is done
-                if (split && g_device_count > 1 && id != g_main_device) {
-                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                // add event for the main device to wait on until other device is done
+                if (split && (id != g_main_device || is != 0)) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
                 }
             }
         }
     }
 
-    // wait until each device is finished, then free their buffers
-    for (int id = 0; id < g_device_count; ++id) {
-        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
-            continue;
-        }
-
-        CUDA_CHECK(cudaSetDevice(id));
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
-        if (src0_asq[id] > 0) {
-            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
-        }
-        if (src0_asf[id] > 0) {
-            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
         }
         if (src1_asf[id] > 0) {
             ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
         }
-        if (dst_asf[id] > 0) {
-            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
         }
     }
 
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
-        for (int id = 0; id < g_device_count; ++id) {
-            if (id != g_main_device && src0_extra->events[id]) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+        is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
+
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            for (int64_t is = 0; is < is_max; ++is) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
             }
         }
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
-    // Due to flatten_rows == true this does in practice not make a difference however.
-    // Better solution would be nice but right now that would require disproportionate changes.
-    GGML_ASSERT(
-        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
-        src1->type == GGML_TYPE_F32 &&
-        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
 void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
 
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
 void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -6254,8 +6209,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -6266,7 +6221,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -6285,8 +6240,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -6297,38 +6252,49 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    const int row_stride_x = nb01 / sizeof(half);
-    const int channel_stride_x = nb02 / sizeof(half);
+    const int64_t row_stride_x = nb01 / sizeof(half);
+    const int64_t channel_stride_x = nb02 / sizeof(half);
 
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
 
+    int64_t min_compute_capability = INT_MAX;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (min_compute_capability > g_compute_capabilities[id]
+                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            min_compute_capability = g_compute_capabilities[id];
+        }
+    }
+
     if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
     } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     }else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
-        } else {
-            int min_compute_capability = INT_MAX;
-            for (int id = 0; id < g_device_count; ++id) {
-                if (min_compute_capability > g_compute_capabilities[id]
-                        && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-                    min_compute_capability = g_compute_capabilities[id];
-                }
-            }
 
-            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
-                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
+#ifdef GGML_CUDA_FORCE_DMMV
+            const bool use_mul_mat_vec_q = false;
+#else
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+#endif // GGML_CUDA_FORCE_DMMV
+
+            if (use_mul_mat_vec_q) {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+            } else {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+            }
+        } else {
+            if (src1->backend == GGML_BACKEND_GPU && g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
             } else {
-                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
             }
         }
     } else {
@@ -6337,8 +6303,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
 }
 
 void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -6367,8 +6332,8 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -6378,10 +6343,10 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         GGML_ASSERT(false);
     }
@@ -6395,25 +6360,20 @@ void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 }
 
 void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }
 
 void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }
 
 void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -6423,7 +6383,7 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 }
 
 void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
-    int nrows = ggml_nrows(tensor);
+    const int64_t nrows = ggml_nrows(tensor);
 
     const int64_t ne0 = tensor->ne[0];
 
@@ -6433,14 +6393,14 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
 
-    for (int id = 0; id < g_device_count; ++id) {
+    for (int64_t id = 0; id < g_device_count; ++id) {
         if (backend == GGML_BACKEND_GPU && id != g_main_device) {
             continue;
         }
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
 
-        int row_low, row_high;
+        int64_t row_low, row_high;
         if (backend == GGML_BACKEND_GPU) {
             row_low = 0;
             row_high = nrows;
@@ -6490,7 +6450,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         extra->data_device[id] = buf;
 
         if (backend == GGML_BACKEND_GPU_SPLIT) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+            for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+                CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+            }
         }
     }
 
@@ -6504,15 +6466,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
-    for (int id = 0; id < g_device_count; ++id) {
+    for (int64_t id = 0; id < g_device_count; ++id) {
         if (extra->data_device[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(ggml_cuda_set_device(id));
             CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        if (extra->events[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
-            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+            if (extra->events[id][is] != nullptr) {
+                CUDA_CHECK(ggml_cuda_set_device(id));
+                CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+            }
         }
     }
 
@@ -6564,7 +6528,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];