From: Georgi Gerganov
Date: Thu, 22 Jan 2026 20:09:01 +0000 (+0200)
Subject: mla : make the V tensor a view of K (llama/18986)
X-Git-Tag: v0.9.6~40
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=91d417c5954c1cb45d83795f23c8b458730731c1;p=pkg%2Fggml%2Fsources%2Fggml

mla : make the V tensor a view of K (llama/18986)

* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler

---------

Co-authored-by: Johannes Gäßler
---

diff --git a/src/ggml-cuda/fattn-common.cuh b/src/ggml-cuda/fattn-common.cuh
index 8468ba84..a781fb91 100644
--- a/src/ggml-cuda/fattn-common.cuh
+++ b/src/ggml-cuda/fattn-common.cuh
@@ -778,12 +778,15 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    // TODO: make this more generic by removing the notion of "MLA".
+    //       for example "is V a view of K?" so we can skip loading it.
+    //       V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
+
     GGML_ASSERT(V || is_mla);
 
     const ggml_tensor * mask = dst->src[3];
diff --git a/src/ggml-cuda/fattn-mma-f16.cuh b/src/ggml-cuda/fattn-mma-f16.cuh
index 8cca89c2..203569e3 100644
--- a/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/src/ggml-cuda/fattn-mma-f16.cuh
@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // For MLA K and V have the same data.
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
 
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
 
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9f61c648..146d05f5 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            //   - https://github.com/ggml-org/llama.cpp/pull/13435
+            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            //   - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
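
For reference, a minimal sketch (not part of the patch) of how a caller can hand the FA op a V tensor that is a view of K, mirroring the new MLA branch in test-backend-ops.cpp; q, mask, kv, n_head, n_seq, scale and the F16 type are illustrative placeholders, and the usual permutation/precision setup is omitted:

    // MLA-style cache: K has head size 576; the test above views the first 512 values per head as V
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 576, kv, n_head, n_seq);
    // V is a 512-wide sub-view of K at offset 0, so no separate V data is stored
    ggml_tensor * v = ggml_view_4d(ctx, k, 512, kv, n_head, n_seq,
                                   k->nb[1], k->nb[2], k->nb[3], 0);
    // launch_fattn() now detects this case via: V->op == GGML_OP_VIEW && V->src[0] == K
    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f);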