]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix matrix multiplication algorithm choice (#8102)
authorJohannes Gäßler <redacted>
Mon, 24 Jun 2024 23:22:33 +0000 (01:22 +0200)
committerGitHub <redacted>
Mon, 24 Jun 2024 23:22:33 +0000 (01:22 +0200)
ggml-cuda.cu

index 2dda03924253196d6d6fa4a26af7029dd90a1fdd..0acfda91d3e51b556d5c704abf0e12fe5f8f6ea6 100644 (file)
@@ -1924,16 +1924,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
-               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch without FlashAttention
-        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }