]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
mla : make the V tensor a view of K (llama/18986)
authorGeorgi Gerganov <redacted>
Thu, 22 Jan 2026 20:09:01 +0000 (22:09 +0200)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 13:56:40 +0000 (15:56 +0200)
* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh

index 8468ba8488d8c9028e157846daf5224074472c80..a781fb91f5b450cfcd3ea006fa4d1f5c6fa9dc08 100644 (file)
@@ -778,12 +778,15 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    // TODO: make this more generic by removing the notion of "MLA".
+    //       for example "is V a view of K?" so we can skip loading it.
+    //       V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
+
     GGML_ASSERT(V || is_mla);
 
     const ggml_tensor * mask  = dst->src[3];
index 8cca89c2bfa7844cd61105e2d92fff55f6939e47..203569e3459bb34888937be6d5643a9a8ce3c2a0 100644 (file)
@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // For MLA K and V have the same data.
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;