]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : remove LLAMA_SET_ROWS checks (#15505)
authorGeorgi Gerganov <redacted>
Thu, 28 Aug 2025 09:27:02 +0000 (12:27 +0300)
committerGitHub <redacted>
Thu, 28 Aug 2025 09:27:02 +0000 (12:27 +0300)
ggml-ci

ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h

index 33794062f565d5b7473f946a4206d6002c407463..88cc3f481ed3832c191d1a5e8f156677d23669f7 100755 (executable)
@@ -374,7 +374,6 @@ struct ggml_backend_cann_context {
 #endif
     cann_task_queue task_queue;
     bool async_mode;
-    bool support_set_rows;
     // Rope Cache
     void* rope_init_ptr = nullptr;
     void* rope_sin_ptr = nullptr;
@@ -400,14 +399,6 @@ struct ggml_backend_cann_context {
         async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
         GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
             device, async_mode ? "ON" : "OFF");
-
-        support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
-        GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
-
-        if (!support_set_rows) {
-            GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
-                    "Falling back to eager mode.\n", __func__);
-        }
     }
 
     /**
index 81215425618a35466430ce24d542db2a15c788e7..558121dff780b8e536e9ebce62c3a1a43e3c09ec 100755 (executable)
@@ -2251,11 +2251,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;
 
-    // check environment LLAMA_SET_ROWS
-    if (!cann_ctx->support_set_rows) {
-        use_cann_graph = false;
-    }
-
     if (use_cann_graph) {
         if (cann_ctx->cann_graph == nullptr) {
             cann_ctx->cann_graph.reset(new ggml_cann_graph());
index 99bfed75136b2b4e1172136df1ae78fcaea68b15..6b20161a389a0f8a18178dcebb975e1274ebb7d5 100644 (file)
@@ -102,16 +102,6 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
-    {
-        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
-
-        if (!supports_set_rows && !cparams.kv_unified) {
-            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
-            cparams.kv_unified = true;
-        }
-    }
-
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -890,12 +880,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -1226,12 +1210,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
     return 0;
 }
 
index 3dd9205446483eb424b7b191ceaf41fc77f012d1..a372bcfbe41aa6883f98b685ed83f729a108c22d 100644 (file)
@@ -283,10 +283,6 @@ private:
 
     bool has_evaluated_once = false;
 
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;
 
index b928e9e16ead8c6781b7f5effdd3cb1a79424a85..1f2fc3ab62d4e630d54dd45e3ef23662c45b191f 100644 (file)
@@ -314,8 +314,6 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
-    res &= mctx->get_supports_set_rows(); // TODO: tmp
-
     return res;
 }
 
@@ -350,8 +348,6 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
     res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
-    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
-
     return res;
 }
 
index 920c1d0dbdc745b99ef4d16301c0a63a57e77d28..4485f78d5f5330bacac182aec38cd884583cf3ab 100644 (file)
@@ -197,18 +197,6 @@ llama_kv_cache::llama_kv_cache(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
-
-    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
-
-    if (!supports_set_rows) {
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
-    }
-
-    if (!supports_set_rows) {
-        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
-    }
 }
 
 void llama_kv_cache::clear(bool data) {
@@ -551,11 +539,8 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
-        // non-continuous slots require support for ggml_set_rows()
-        const bool cont = supports_set_rows ? false : true;
-
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch, cont);
+        const auto sinfo_new = find_slot(ubatch, true);
         if (sinfo_new.empty()) {
             success = false;
             break;
@@ -976,10 +961,6 @@ uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     return result;
 }
 
-bool llama_kv_cache::get_supports_set_rows() const {
-    return supports_set_rows;
-}
-
 ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
@@ -1033,36 +1014,26 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
 }
 
 ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+    GGML_UNUSED(sinfo);
+
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
-    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (k_idxs && supports_set_rows) {
-        if (k->ne[2] > 1) {
-            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
-        }
-
-        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    if (k->ne[2] > 1) {
+        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
     }
 
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
-
-    ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
-
-    return ggml_cpy(ctx, k_cur, k_view);
+    return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
 ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
+    GGML_UNUSED(sinfo);
+
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1072,48 +1043,25 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (v_idxs && supports_set_rows) {
-        if (!v_trans) {
-            if (v->ne[2] > 1) {
-                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
-            }
-
-            return ggml_set_rows(ctx, v, v_cur, v_idxs);
-        }
-
-        // [TAG_V_CACHE_VARIABLE]
-        if (n_embd_v_gqa < v->ne[0]) {
-            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (!v_trans) {
+        if (v->ne[2] > 1) {
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
         }
 
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
-
-        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
-
-        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+        return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_v_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    }
 
-    ggml_tensor * v_view = nullptr;
+    // the row becomes a single element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-    if (!v_trans) {
-        v_view = ggml_view_1d(ctx, v,
-                n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
-    } else {
-        v_cur = ggml_transpose(ctx, v_cur);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1]    )*ggml_element_size(v),
-                (sinfo.head())*ggml_element_size(v));
-    }
-
-    return ggml_cpy(ctx, v_cur, v_view);
+    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
 
 ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
@@ -1143,10 +1091,6 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
 }
 
 void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
@@ -1163,10 +1107,6 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub
 }
 
 void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
@@ -2004,10 +1944,6 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
     return n_kv;
 }
 
-bool llama_kv_cache_context::get_supports_set_rows() const {
-    return kv->get_supports_set_rows();
-}
-
 ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
index 3ca82917d3237d25c7a138263fb2197ffc0f9e76..07d29bb8187e27d1e20a3c36c1bcb2a142591398 100644 (file)
@@ -141,9 +141,6 @@ public:
 
     uint32_t get_n_kv(const slot_info & sinfo) const;
 
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
@@ -215,10 +212,6 @@ private:
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr>        ctxs;
@@ -318,9 +311,6 @@ public:
 
     uint32_t get_n_kv() const;
 
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;