]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: use FA + max. GPU layers by default (#15434)
authorJohannes Gäßler <redacted>
Sat, 30 Aug 2025 14:32:10 +0000 (16:32 +0200)
committerGitHub <redacted>
Sat, 30 Aug 2025 14:32:10 +0000 (16:32 +0200)
* llama: use max. GPU layers by default, auto -fa

* ggml-backend: abort instead of segfault

19 files changed:
common/arg.cpp
common/common.cpp
common/common.h
examples/diffusion/diffusion-cli.cpp
ggml/src/ggml-backend.cpp
include/llama.h
scripts/server-bench.py
scripts/tool_bench.py
src/llama-context.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-impl.h
src/llama-model.cpp
src/llama.cpp
tools/batched-bench/batched-bench.cpp
tools/llama-bench/llama-bench.cpp
tools/server/tests/unit/test_ctx_shift.py
tools/server/tests/unit/test_speculative.py
tools/server/tests/utils.py

index 93f0108b2b93bc646fb26822a65cd2459d73c9b7..72c69c39a0fe199bdedc89fbe2ef184ae6a83e45 100644 (file)
@@ -1545,10 +1545,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
-        {"-fa", "--flash-attn"},
-        string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
-        [](common_params & params) {
-            params.flash_attn = true;
+        {"-fa", "--flash-attn"}, "FA",
+        string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
+        [](common_params & params, const std::string & value) {
+            if (value == "on" || value == "enabled") {
+                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            } else if (value == "off" || value == "disabled") {
+                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+            } else if (value == "auto") {
+                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+            } else {
+                throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
+            }
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
     add_opt(common_arg(
@@ -3459,8 +3467,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
@@ -3475,8 +3481,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
@@ -3491,8 +3495,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
             params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
@@ -3508,10 +3510,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
             params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
             params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
-            params.speculative.n_gpu_layers = 99;
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
@@ -3527,10 +3526,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
             params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
             params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
-            params.speculative.n_gpu_layers = 99;
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
@@ -3545,8 +3541,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
             params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
             params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
             params.n_ubatch = 1024;
             params.n_batch = 1024;
             params.n_ctx = 0;
index 054b43be770da67bc7442ebbc5cb816ce4e52cb7..0c92d4d57ddbf7c6aa64ae73fcae2c8dad3bc98b 100644 (file)
@@ -901,7 +901,8 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+            __func__, params.model.path.c_str());
         return iparams;
     }
 
@@ -911,7 +912,8 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
-        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+            __func__, params.model.path.c_str());
         llama_model_free(model);
         return iparams;
     }
@@ -1157,10 +1159,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
     cparams.attention_type    = params.attention_type;
+    cparams.flash_attn_type   = params.flash_attn_type;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
-    cparams.flash_attn        = params.flash_attn;
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
     cparams.swa_full          = params.swa_full;
index 87ea0606954a359293c07668f0af1ad4d92d9e8b..02ca093bdf8b7c7269bad7524fda147f3f12fa42 100644 (file)
@@ -312,6 +312,7 @@ struct common_params {
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
+    enum llama_flash_attn_type   flash_attn_type   = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
 
     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
@@ -375,7 +376,6 @@ struct common_params {
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
-    bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
     bool ctx_shift         = false;  // context shift on infinite text generation
     bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
index 8431dcea8fe2af1c3339f0eb1d89e060a66fed41..abf7fb357326bcdef03b626c2d9382f1e5ad57b9 100644 (file)
@@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_ctx                = params.n_ctx;
     ctx_params.n_batch              = params.n_batch;
     ctx_params.n_ubatch             = params.n_ubatch;
-    ctx_params.flash_attn           = params.flash_attn;
+    ctx_params.flash_attn_type      = params.flash_attn_type;
     ctx_params.no_perf              = params.no_perf;
     ctx_params.type_k               = params.cache_type_k;
     ctx_params.type_v               = params.cache_type_v;
index e34feccc98a5ec16c1dcb1353b0d04a6e81051e0..02375337c4dd68076b68d79941a0dd6d90a446b5 100644 (file)
@@ -31,6 +31,7 @@
 // backend buffer type
 
 const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(buft);
     return buft->iface.get_name(buft);
 }
 
@@ -40,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
         return ggml_backend_buffer_init(buft, {}, NULL, 0);
     }
 
+    GGML_ASSERT(buft);
     return buft->iface.alloc_buffer(buft, size);
 }
 
 size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(buft);
     return buft->iface.get_alignment(buft);
 }
 
 size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(buft);
     // get_max_size is optional, defaults to SIZE_MAX
     if (buft->iface.get_max_size) {
         return buft->iface.get_max_size(buft);
@@ -56,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(buft);
     // get_alloc_size is optional, defaults to ggml_nbytes
     if (buft->iface.get_alloc_size) {
         size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -66,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s
 }
 
 bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(buft);
     if (buft->iface.is_host) {
         return buft->iface.is_host(buft);
     }
@@ -73,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
 }
 
 ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(buft);
     return buft->device;
 }
 
@@ -110,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
 }
 
 size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     return buffer->size;
 }
 
 void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     // get_base is optional if the buffer is zero-sized
     if (buffer->size == 0) {
         return NULL;
@@ -127,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
 }
 
 enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    GGML_ASSERT(buffer);
     // init_tensor is optional
     if (buffer->iface.init_tensor) {
         return buffer->iface.init_tensor(buffer, tensor);
@@ -135,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s
 }
 
 void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_ASSERT(buffer);
     // clear is optional if the buffer is zero-sized
     if (buffer->size == 0) {
         return;
@@ -160,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
 }
 
 void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    GGML_ASSERT(buffer);
     buffer->usage = usage;
 
     // FIXME: add a generic callback to the buffer interface
@@ -169,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
 }
 
 enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     return buffer->usage;
 }
 
 ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     return buffer->buft;
 }
 
 void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     if (buffer->iface.reset) {
         buffer->iface.reset(buffer);
     }
@@ -215,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) {
 }
 
 ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    GGML_ASSERT(backend);
     return ggml_backend_dev_buffer_type(backend->device);
 }
 
@@ -231,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
 }
 
 void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(backend);
+    GGML_ASSERT(tensor);
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
@@ -242,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
 }
 
 void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(backend);
+    GGML_ASSERT(tensor);
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
@@ -283,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
 }
 
 void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {
@@ -298,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
+    GGML_ASSERT(backend);
     if (backend->iface.synchronize == NULL) {
         return;
     }
@@ -306,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
 }
 
 ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.graph_plan_create != NULL);
 
     return backend->iface.graph_plan_create(backend, cgraph);
 }
 
 void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.graph_plan_free != NULL);
 
     backend->iface.graph_plan_free(backend, plan);
 }
 
 enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
 
     return backend->iface.graph_plan_compute(backend, plan);
@@ -330,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
 }
 
 enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend);
     return backend->iface.graph_compute(backend, cgraph);
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    GGML_ASSERT(backend);
     return ggml_backend_dev_supports_op(backend->device, op);
 }
 
 bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(backend);
     return ggml_backend_dev_supports_buft(backend->device, buft);
 }
 
 bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    GGML_ASSERT(backend);
     return ggml_backend_dev_offload_op(backend->device, op);
 }
 
 ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
+    GGML_ASSERT(backend);
     return backend->device;
 }
 
@@ -381,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
         return;
     }
 
+    GGML_ASSERT(backend_dst);
     if (backend_dst->iface.cpy_tensor_async != NULL) {
         if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
             return;
@@ -412,18 +443,21 @@ void ggml_backend_event_free(ggml_backend_event_t event) {
 }
 
 void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
+    GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.event_record != NULL);
 
     backend->iface.event_record(backend, event);
 }
 
 void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+    GGML_ASSERT(event);
     GGML_ASSERT(event->device->iface.event_synchronize);
 
     event->device->iface.event_synchronize(event->device, event);
 }
 
 void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    GGML_ASSERT(backend);
     GGML_ASSERT(backend->iface.event_wait != NULL);
 
     backend->iface.event_wait(backend, event);
@@ -432,18 +466,22 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
 // Backend device
 
 const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     return device->iface.get_name(device);
 }
 
 const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     return device->iface.get_description(device);
 }
 
 void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+    GGML_ASSERT(device);
     device->iface.get_memory(device, free, total);
 }
 
 enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     return device->iface.get_type(device);
 }
 
@@ -453,18 +491,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d
 }
 
 ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     return device->reg;
 }
 
 ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
+    GGML_ASSERT(device);
     return device->iface.init_backend(device, params);
 }
 
 ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     return device->iface.get_buffer_type(device);
 }
 
 ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    GGML_ASSERT(device);
     if (device->iface.get_host_buffer_type == NULL) {
         return NULL;
     }
@@ -473,18 +515,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t
 }
 
 ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
+    GGML_ASSERT(device);
     return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
 }
 
 bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+    GGML_ASSERT(device);
     return device->iface.supports_op(device, op);
 }
 
 bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(device);
     return device->iface.supports_buft(device, buft);
 }
 
 bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+    GGML_ASSERT(device);
     if (device->iface.offload_op != NULL) {
         return device->iface.offload_op(device, op);
     }
@@ -495,18 +541,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te
 // Backend (reg)
 
 const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
+    GGML_ASSERT(reg);
     return reg->iface.get_name(reg);
 }
 
 size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
+    GGML_ASSERT(reg);
     return reg->iface.get_device_count(reg);
 }
 
 ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(reg);
     return reg->iface.get_device(reg, index);
 }
 
 void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    GGML_ASSERT(reg);
     if (!reg->iface.get_proc_address) {
         return NULL;
     }
@@ -521,6 +571,7 @@ struct ggml_backend_multi_buffer_context {
 };
 
 static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
     for (size_t i = 0; i < ctx->n_buffers; i++) {
         ggml_backend_buffer_free(ctx->buffers[i]);
@@ -531,6 +582,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
 }
 
 static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_ASSERT(buffer);
     ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
     for (size_t i = 0; i < ctx->n_buffers; i++) {
         ggml_backend_buffer_clear(ctx->buffers[i], value);
@@ -566,10 +618,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
 }
 
 bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
 }
 
 void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    GGML_ASSERT(buffer);
     GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
     ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
     for (size_t i = 0; i < ctx->n_buffers; i++) {
@@ -1349,6 +1403,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
 }
 
 static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     struct ggml_backend_sched_split * splits = sched->splits;
 
     ggml_tensor * prev_ids_tensor = nullptr;
@@ -1617,6 +1672,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
 }
 
 void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     // reset state for the next run
     if (!sched->is_reset) {
         ggml_hash_set_reset(&sched->hash_set);
@@ -1628,6 +1684,7 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
 }
 
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
     ggml_backend_sched_synchronize(sched);
@@ -1644,6 +1701,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 }
 
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
     GGML_ASSERT(!sched->is_alloc);
 
@@ -1668,6 +1726,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT(sched);
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
@@ -1682,6 +1741,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
 }
 
 void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     for (int i = 0; i < sched->n_backends; i++) {
         ggml_backend_synchronize(sched->backends[i]);
     }
@@ -1694,28 +1754,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
 }
 
 void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    GGML_ASSERT(sched);
     sched->callback_eval = callback;
     sched->callback_eval_user_data = user_data;
 }
 
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     return sched->n_splits;
 }
 
 int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     return sched->n_copies;
 }
 
 int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    GGML_ASSERT(sched);
     return sched->n_backends;
 }
 
 ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(sched);
     GGML_ASSERT(i >= 0 && i < sched->n_backends);
     return sched->backends[i];
 }
 
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    GGML_ASSERT(sched);
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
 
@@ -1723,6 +1789,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
 }
 
 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+    GGML_ASSERT(sched);
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     tensor_backend_id(node) = backend_index;
@@ -1731,6 +1798,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
 }
 
 ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+    GGML_ASSERT(sched);
     int backend_index = tensor_backend_id(node);
     if (backend_index == -1) {
         return NULL;
@@ -1741,6 +1809,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
 // utils
 
 enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->view_src != NULL);
     GGML_ASSERT(tensor->view_src->buffer != NULL);
@@ -1752,6 +1821,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
 }
 
 enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor);
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->data == NULL);
     GGML_ASSERT(tensor->view_src == NULL);
@@ -1825,6 +1895,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
 }
 
 struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    GGML_ASSERT(graph);
     struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
     struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
     bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
@@ -1969,6 +2040,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 // CPU backend - buffer
 
 static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     uintptr_t data = (uintptr_t)buffer->context;
 
     // align the buffer
@@ -1980,28 +2052,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
 }
 
 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    GGML_ASSERT(buffer);
     ggml_aligned_free(buffer->context, buffer->size);
 }
 
 static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     memset((char *)tensor->data + offset, value, size);
 
     GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     memcpy((char *)tensor->data + offset, data, size);
 
     GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     memcpy(data, (const char *)tensor->data + offset, size);
 
     GGML_UNUSED(buffer);
 }
 
 static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    GGML_ASSERT(src);
     if (ggml_backend_buffer_is_host(src->buffer)) {
         memcpy(dst->data, src->data, ggml_nbytes(src));
         return true;
@@ -2012,6 +2089,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
 }
 
 static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_ASSERT(buffer);
     memset(buffer->context, value, buffer->size);
 }
 
index 702535385057fcb2529cbfe05363cdee1c85919a..346135c71e2e6687089a6048c261bc8800a62932 100644 (file)
@@ -179,6 +179,14 @@ extern "C" {
         LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
     };
 
+    enum llama_flash_attn_type {
+        LLAMA_FLASH_ATTN_TYPE_AUTO     = -1,
+        LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+        LLAMA_FLASH_ATTN_TYPE_ENABLED  = 1,
+    };
+
+    LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -303,6 +311,7 @@ extern "C" {
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type    attention_type;    // attention type to use for embeddings
+        enum llama_flash_attn_type   flash_attn_type;   // when to enable Flash Attention
 
         // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -329,7 +338,6 @@ extern "C" {
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // use flash attention [EXPERIMENTAL]
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
index a71602017340afd65adb9541c3bf2ca3750b85f9..dbbb0939ffef9058b7e9df51a0a3bd378811bda9 100755 (executable)
@@ -151,12 +151,6 @@ def benchmark(
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
-        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
-        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
-        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
-        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
     parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
     prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
index d8018e2e23c0dfe6360a08c4b36c01cbd65a1790..05d6dfc30a36ddd190473e8b46d5fb053fb69abe 100755 (executable)
@@ -323,7 +323,7 @@ def run(
                     server.jinja = True
                     server.ctk = ctk
                     server.ctv = ctv
-                    server.fa = fa
+                    server.fa = "on" if fa else "off"
                     server.n_predict = n_predict
                     server.model_hf_repo = hf
                     server.model_hf_file = None
index 6b20161a389a0f8a18178dcebb975e1274ebb7d5..ac8453ab741d4a29272007c96e83658b1fdefc06 100644 (file)
@@ -41,7 +41,6 @@ llama_context::llama_context(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
     cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
     cparams.warmup           = false;
@@ -86,6 +85,8 @@ llama_context::llama_context(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
@@ -119,7 +120,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
-    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: flash_attn    = %s\n",   __func__, llama_flash_attn_type_name(params.flash_attn_type));
     LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -269,7 +270,7 @@ llama_context::llama_context(
         }
     }
 
-    // reserve worst-case graph
+    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
         const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -300,6 +301,48 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
+            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+                ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+                bool fa_device_mismatch = false;
+                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                    ggml_tensor * n = ggml_graph_node(gf, i);
+                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                        continue;
+                    }
+                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                    const int il = std::stoi(n->name + prefix_len);
+                    ggml_backend_dev_t device_kv = model.dev_layer(il);
+                    if (device_fa != device_kv) {
+                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                            "is assigned to device %s (usually due to missing support)\n",
+                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                        fa_device_mismatch = true;
+                        break;
+                    }
+                }
+                if (fa_device_mismatch) {
+                    cparams.flash_attn = false;
+                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                    if (ggml_is_quantized(params.type_v)) {
+                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                    }
+                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+                    if (!gf) {
+                        throw std::runtime_error("failed to allocate compute pp buffers");
+                    }
+                } else {
+                    cparams.flash_attn = true;
+                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+                }
+            }
+
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp  = ggml_graph_n_nodes(gf);
         }
@@ -2208,6 +2251,7 @@ llama_context_params llama_context_default_params() {
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.flash_attn_type             =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -2224,7 +2268,6 @@ llama_context_params llama_context_default_params() {
         /*.abort_callback_data         =*/ nullptr,
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
@@ -2252,12 +2295,30 @@ llama_context * llama_init_from_model(
         return nullptr;
     }
 
-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+    if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
         LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_k);
+        if (model->hparams.n_embd_head_k % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+            return nullptr;
+        }
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_v);
+        if (model->hparams.n_embd_head_v % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+            return nullptr;
+        }
     }
 
-    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
index 1f2fc3ab62d4e630d54dd45e3ef23662c45b191f..49ea5da7cb42219f92d764fddfe36e6d293b7dfb 100644 (file)
@@ -1221,7 +1221,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_tensor * kq_mask,
          ggml_tensor * sinks,
          ggml_tensor * v_mla,
-             float     kq_scale) const {
+               float   kq_scale,
+                 int   il) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
     // split the batch into streams if needed
@@ -1256,6 +1257,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
 
         ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
@@ -1271,6 +1273,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             // The permutations are noops and only change how the tensor data is interpreted.
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cb(cur, "fattn_mla", il);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
 #endif
@@ -1279,6 +1282,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+        cb(kq, "kq", il);
 
         // note: this op tends to require high floating point range
         //       while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1292,32 +1296,42 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             // before the softmax below
 
             kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+            cb(kq, "kq_tanh", il);
             kq = ggml_scale(ctx0, kq, 30);
+            cb(kq, "kq_scaled", il);
         }
 
         if (hparams.attn_soft_cap) {
             kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_1", il);
             kq = ggml_tanh (ctx0, kq);
+            cb(kq, "kq_tanh", il);
             kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_2", il);
         }
 
         if (kq_b) {
             kq = ggml_add(ctx0, kq, kq_b);
+            cb(kq, "kq_plus_kq_b", il);
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         ggml_soft_max_add_sinks(kq, sinks);
+        cb(kq, "kq_soft_max", il);
 
         if (!v_trans) {
             // note: avoid this branch
             v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+            cb(v, "v_cont", il);
         }
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+        cb(kqv, "kqv", il);
 
         // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
         if (v_mla) {
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+            cb(kqv, "kqv_mla", il);
         }
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
@@ -1378,7 +1392,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1467,7 +1481,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1534,7 +1548,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1589,7 +1603,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
index e11d91d5293f09494937eb6daccf3eff5e2ecec1..3c85333fde5145a424949dda079d3d8e15d33af9 100644 (file)
@@ -687,7 +687,8 @@ struct llm_graph_context {
             ggml_tensor * kq_mask,
             ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                  float   kq_scale) const;
+                  float   kq_scale,
+                    int   il) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
index 02b1d07f8400dc3caa75f1766bf46cdc33685d65..c5163e9225a5e04ee38c426aad139bf15ddc9a90 100644 (file)
@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
 std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
index f3e0e9ac64b0df8342dbae4f772bfcee349c40e8..58a0581e26de0c3a444676e09a23549b7ce3d91e 100644 (file)
@@ -18994,7 +18994,7 @@ llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices                     =*/ nullptr,
         /*.tensor_buft_overrides       =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
+        /*.n_gpu_layers                =*/ 999,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -19008,11 +19008,6 @@ llama_model_params llama_model_default_params() {
         /*.use_extra_bufts             =*/ true,
     };
 
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
     return result;
 }
 
index 34906cdb62844875bf572a2a1df6118a2a8aa885..f0d4f5f891cc7106a01e85fc091054fc09a27f7e 100644 (file)
 // interface implementation
 //
 
+const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO:
+            return "auto";
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED:
+            return "disabled";
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED:
+            return "enabled";
+    }
+    GGML_ABORT("fatal error");
+}
+
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
         /*.no_perf                     =*/ true,
index 23d03039dcc036f0c28078cb5e6666d3b378c86f..46dd12caae544311631b106267481f5f9224c4bf 100644 (file)
@@ -111,7 +111,7 @@ int main(int argc, char ** argv) {
 
     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -197,7 +197,7 @@ int main(int argc, char ** argv) {
                     LOG(
                         "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
                         "\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n",
-                        n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
+                        n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
                         pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed
                     );
                 } else {
index 9378706a12a7c9c31fe656137bfbc7a8fa59f7f6..9b9803dedabef8295d26aaca16a478ec77d5da0d 100644 (file)
@@ -987,16 +987,16 @@ struct cmd_params_instance {
     llama_context_params to_llama_cparams() const {
         llama_context_params cparams = llama_context_default_params();
 
-        cparams.n_ctx        = n_prompt + n_gen + n_depth;
-        cparams.n_batch      = n_batch;
-        cparams.n_ubatch     = n_ubatch;
-        cparams.type_k       = type_k;
-        cparams.type_v       = type_v;
-        cparams.offload_kqv  = !no_kv_offload;
-        cparams.flash_attn   = flash_attn;
-        cparams.embeddings   = embeddings;
-        cparams.op_offload   = !no_op_offload;
-        cparams.swa_full     = false;
+        cparams.n_ctx           = n_prompt + n_gen + n_depth;
+        cparams.n_batch         = n_batch;
+        cparams.n_ubatch        = n_ubatch;
+        cparams.type_k          = type_k;
+        cparams.type_v          = type_v;
+        cparams.offload_kqv     = !no_kv_offload;
+        cparams.flash_attn_type = flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
+        cparams.embeddings      = embeddings;
+        cparams.op_offload      = !no_op_offload;
+        cparams.swa_full        = false;
 
         return cparams;
     }
index 8f51bc301a74c44d611ce8d29a5d3b946196ac79..92e49f2bb05a48b1825796fa4d9bd14be90c61d0 100644 (file)
@@ -15,25 +15,26 @@ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deseru
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
-    server.n_ctx = 256
+    server.n_ctx = 512
     server.n_slots = 2
+    server.n_predict = 128
 
 
 def test_ctx_shift_enabled():
     # the prompt is 301 tokens
-    # the slot context is 256/2 = 128 tokens
-    # the prompt is truncated to keep the last 109 tokens
-    # 64 tokens are generated thanks to shifting the context when it gets full
+    # the slot context is 512/2 = 256 tokens
+    # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
+    # 96 tokens are generated thanks to shifting the context when it gets full
     global server
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": 64,
+        "n_predict": 96,
         "prompt": LONG_TEXT,
     })
     assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 109
-    assert res.body["timings"]["predicted_n"] == 64
+    assert res.body["timings"]["prompt_n"] == 173
+    assert res.body["timings"]["predicted_n"] == 96
     assert res.body["truncated"] is True
 
 
index 38ca4325ba675004aca3c58603beacfcc185ae5b..65952de8b8d4c51a8ba2539996a130f53d8811c3 100644 (file)
@@ -14,6 +14,7 @@ def create_server():
     server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
     server.draft_min = 4
     server.draft_max = 8
+    server.fa = "off"
 
 
 @pytest.fixture(autouse=True)
index f55a539471194e7ea34ec96ab1d498aec325ff4e..82f7215d537dba3e5b3aa249cd9d7d99fc0c4028 100644 (file)
@@ -66,7 +66,7 @@ class ServerProcess:
     n_slots: int | None = None
     ctk: str | None = None
     ctv: str | None = None
-    fa: bool | None = None
+    fa: str | None = None
     server_continuous_batching: bool | None = False
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
@@ -161,7 +161,7 @@ class ServerProcess:
         if self.ctv:
             server_args.extend(["-ctv", self.ctv])
         if self.fa is not None:
-            server_args.append("-fa")
+            server_args.extend(["-fa", self.fa])
         if self.n_predict:
             server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
@@ -427,7 +427,7 @@ class ServerPreset:
         server.n_batch = 300
         server.n_ubatch = 300
         server.n_slots = 2
-        server.fa = True
+        server.fa = "on"
         server.seed = 42
         server.server_embeddings = True
         return server