]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add function for model-based max number of graph nodes (#8622)
authorGeorgi Gerganov <redacted>
Sat, 27 Jul 2024 11:59:29 +0000 (14:59 +0300)
committerGitHub <redacted>
Sat, 27 Jul 2024 11:59:29 +0000 (14:59 +0300)
* llama : model-based max number of graph nodes

ggml-ci

* llama : disable 405B max_nodes path due to lack of complaints

ggml-ci

src/llama.cpp

index bc830c0ef8da48c9e41c24615d6cf9ac8f215d3d..c9cdbb34344d5afc50a685489638600ab6e2d58d 100644 (file)
 #endif
 
 // bump if necessary
-#define LLAMA_MAX_NODES   8192
 #define LLAMA_MAX_LAYERS  512
 #define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
 
@@ -3567,6 +3566,15 @@ namespace GGUFMeta {
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
 
+// TODO: update when needed or think of some clever automatic way to do this
+static size_t llama_model_max_nodes(const llama_model & /*model*/) {
+    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
+    //    return 32768;
+    //}
+
+    return 8192;
+}
+
 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -8396,7 +8404,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -8427,7 +8435,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.recurrent);
 
@@ -8450,7 +8458,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -8691,7 +8699,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8834,7 +8842,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8949,7 +8957,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9052,7 +9060,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9172,7 +9180,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9329,7 +9337,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9455,7 +9463,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9559,7 +9567,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9653,7 +9661,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9847,7 +9855,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9948,7 +9956,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10238,7 +10246,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10350,7 +10358,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10462,7 +10470,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10608,7 +10616,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10729,7 +10737,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -10961,7 +10969,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11066,7 +11074,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11177,7 +11185,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11295,7 +11303,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11416,7 +11424,7 @@ struct llm_build_context {
     //      https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
     // based on the original build_llama() function
     struct ggml_cgraph * build_minicpm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11560,7 +11568,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -11668,7 +11676,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -11803,7 +11811,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11922,7 +11930,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t d_model = n_embd;
         const int64_t d_conv  = hparams.ssm_d_conv;
@@ -12071,7 +12079,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12225,7 +12233,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12349,7 +12357,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12474,7 +12482,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -12616,7 +12624,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12748,7 +12756,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12976,7 +12984,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13116,7 +13124,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13433,7 +13441,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -13525,7 +13533,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -14870,9 +14878,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -16762,8 +16770,10 @@ struct llama_context * llama_new_context_with_model(
                 }
             }
 
+            const size_t max_nodes = llama_model_max_nodes(*model);
+
             // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
+            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
 
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
@@ -16776,7 +16786,7 @@ struct llama_context * llama_new_context_with_model(
             // currently this is only implemented in the CUDA backend
             pipeline_parallel = false;
 #endif
-            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
+            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
 
             if (pipeline_parallel) {
                 LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));