]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: Prefer vector flash decoding kernel for Gemma models (llama/12738)
authorGaurav Garg <redacted>
Thu, 3 Apr 2025 16:20:29 +0000 (21:50 +0530)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
* Prefer vector flash decoding kernel for Gemma models

Vector flash decoding kernel was not being picked for models with head dimension 256. Gemma models are in this category.
Removing this limit improves e2e performance by upto 12% in gen phase throughput for Gemm models.

* Update ggml/src/ggml-cuda/fattn.cu

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/fattn.cu

index 8edc12649aa63573c3cca0068daa672fa0f0db31..7a2d1e45365af8676f126f611e38230dbd3f7e38 100644 (file)
@@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);