}
}
- // resolve automatic Flash Attention use and reserve worst-case graph
if (!hparams.vocab_only) {
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
- LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
- int n_splits_pp = -1;
- int n_nodes_pp = -1;
-
- int n_splits_tg = -1;
- int n_nodes_tg = -1;
-
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
}
cross.v_embd.clear();
+ // resolve automatic Flash Attention use
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
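+ // build a minimal ubatch (1 token, 1 seq, no outputs) and only split the graph
+ // across backends, without allocating it, to see where the FA tensors are placed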
+ auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+ if (!gf) {
+ throw std::runtime_error("failed to split graph for Flash Attention check");
+ }
+
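+ // FA tensors are named "<LLAMA_TENSOR_NAME_FATTN>-<il>"; the +1 accounts for the "-" separator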
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+ bool fa_device_mismatch = false;
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+ ggml_tensor * n = ggml_graph_node(gf, i);
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+ continue;
+ }
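+ // the device the scheduler assigned the FA tensor to may differ from the
+ // device of the layer's KV cache if that backend does not support FA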
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+ const int il = std::stoi(n->name + prefix_len);
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
+ if (device_fa != device_kv) {
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+ "is assigned to device %s (usually due to missing support)\n",
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+ fa_device_mismatch = true;
+ break;
+ }
+ }
+ if (fa_device_mismatch) {
+ cparams.flash_attn = false;
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+ if (ggml_is_quantized(params.type_v)) {
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+ }
+ } else {
+ cparams.flash_attn = true;
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+ }
+ }
+
+ // reserve worst-case graph
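+ // with a unified KV cache a single sequence can span the full context, so n_seqs = 1 covers the worst case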
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+ int n_splits_pp = -1;
+ int n_nodes_pp = -1;
+
+ int n_splits_tg = -1;
+ int n_nodes_tg = -1;
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
throw std::runtime_error("failed to allocate compute pp buffers");
}
- if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
- ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
- bool fa_device_mismatch = false;
- for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
- ggml_tensor * n = ggml_graph_node(gf, i);
- if (n->op != GGML_OP_FLASH_ATTN_EXT) {
- continue;
- }
- ggml_backend_dev_t device_fa = ggml_backend_get_device(
- ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
- // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
- GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
- const int il = std::stoi(n->name + prefix_len);
- ggml_backend_dev_t device_kv = model.dev_layer(il);
- if (device_fa != device_kv) {
- LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
- "is assigned to device %s (usually due to missing support)\n",
- __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
- // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
- fa_device_mismatch = true;
- break;
- }
- }
- if (fa_device_mismatch) {
- cparams.flash_attn = false;
- LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
- if (ggml_is_quantized(params.type_v)) {
- throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
- }
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
- if (!gf) {
- throw std::runtime_error("failed to allocate compute pp buffers");
- }
- } else {
- cparams.flash_attn = true;
- LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
- }
- }
-
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_pp = ggml_graph_n_nodes(gf);
}
return static_cast<llm_graph_result *>(gf_res_reserve.get());
}
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
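+// note: with split_only == true the graph is only split across backends, no compute buffers are allocated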
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
if (n_tokens % n_seqs != 0) {
this->n_outputs = save_n_outputs;
// initialize scheduler with the specified graph
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ if (split_only) {
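+ // assign the graph nodes to backends and compute the splits, without allocating compute buffers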
+ ggml_backend_sched_split_graph(sched.get(), gf);
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
return nullptr;
}