]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mla : make the V tensor a view of K (llama/18986)
authorGeorgi Gerganov <redacted>
Thu, 22 Jan 2026 20:09:01 +0000 (22:09 +0200)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
tests/test-backend-ops.cpp

index 8468ba8488d8c9028e157846daf5224074472c80..a781fb91f5b450cfcd3ea006fa4d1f5c6fa9dc08 100644 (file)
@@ -778,12 +778,15 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    // TODO: make this more generic by removing the notion of "MLA".
+    //       for example "is V a view of K?" so we can skip loading it.
+    //       V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
+
     GGML_ASSERT(V || is_mla);
 
     const ggml_tensor * mask  = dst->src[3];
index 8cca89c2bfa7844cd61105e2d92fff55f6939e47..203569e3459bb34888937be6d5643a9a8ce3c2a0 100644 (file)
@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // For MLA K and V have the same data.
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
index 9f61c6483de21df9765976490674cc3887a24158..146d05f53bc76d3afd6a8ade8811f4257fde0965 100644 (file)
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            //   - https://github.com/ggml-org/llama.cpp/pull/13435
+            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            //   - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;