From: Georgi Gerganov Date: Thu, 22 Jan 2026 20:09:01 +0000 (+0200) Subject: mla : make the V tensor a view of K (llama/18986) X-Git-Tag: upstream/1.8.3+155~120 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3f96a1da0e89e538f508fb641422f5de83bb4dc4;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp mla : make the V tensor a view of K (llama/18986) * mla : pass V as a view of K to the FA op * cuda : adjust mla logic to new layout * kv-cache : fix rope shift * tests : remove comment * cuda : fix reusable_cutoff Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8468ba84..a781fb91 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -778,12 +778,15 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + // TODO: make this more generic by removing the notion of "MLA". + // for example "is V a view of K?" so we can skip loading it. + // V strides should be driven by V itself and avoid assumption of the data layout + const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; + GGML_ASSERT(V || is_mla); const ggml_tensor * mask = dst->src[3]; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 8cca89c2..203569e3 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // For MLA K and V have the same data. // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; + constexpr int reusable_cutoff = DV; // TODO implement properly #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;