llama-kv-cache-iswa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
+ llama-memory-hybrid-iswa.cpp
llama-memory-recurrent.cpp
llama-mmap.cpp
llama-model-loader.cpp
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
+#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include <cassert>
return res;
}
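+
+// set the inputs of both attention streams (base and SWA) and the recurrent state copies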
+void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
+ const auto * attn_ctx = mctx->get_attn();
+
+ // base tensors may not be allocated if there are no non-SWA attention layers
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
+ attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+ attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+ attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+ }
+
+ // swa tensors may not be allocated if there are no SWA attention layers
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
+ attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
+ attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
+
+ attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
+ }
+
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+ if (inp_rs->s_copy) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->get_recr()->s_copy(i);
+ }
+ }
+}
+
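+// a previously built graph can be reused only if the shapes of all memory-dependent
+// input tensors still match the new ubatch and memory contexts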
+bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ const auto * attn_ctx = mctx->get_attn();
+
+ // base tensors may not be allocated if there are no non-SWA attention layers
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
+ res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+ }
+
+ // swa tensors may not be allocated if there are no SWA attention layers
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
+ res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+ //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
+ res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+ }
+
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+ res &= inp_rs->head == mctx->get_recr()->get_head();
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+ return res;
+}
+
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
// set the inputs only for the active samplers in the current ubatch
std::unordered_set<llama_seq_id> active_samplers;
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
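+
+// build the combined recurrent + iswa attention inputs and register them with the graph result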
+llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
+
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+
+ // build iswa attention input
+ const auto * attn_ctx = mctx_cur->get_attn();
+
+ auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
+
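+    // with a unified KV cache, all sequences share a single stream; otherwise, one stream per unique sequence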
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
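+    // base (non-SWA) attention stream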
+ {
+ const auto n_kv = attn_ctx->get_base()->get_n_kv();
+
+ inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
+ inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+ inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_set_input(inp_attn->self_kq_mask);
+
+ inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
+ }
+
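+    // SWA attention stream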
+ {
+ const auto n_kv = attn_ctx->get_swa()->get_n_kv();
+
+ inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
+ inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+ inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_set_input(inp_attn->self_kq_mask_swa);
+
+ inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
+ }
+
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+ return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
+}
+
void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
+class llama_memory_hybrid_iswa_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
const llama_memory_hybrid_context * mctx;
};
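+
+// combines the iswa attention inputs with the recurrent state inputs for hybrid
+// models whose attention layers use sliding-window attention (SWA)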
+class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
+public:
+ llm_graph_input_mem_hybrid_iswa(
+ const llama_cparams & cparams,
+ std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
+ std::unique_ptr<llm_graph_input_rs> inp_rs,
+ const llama_memory_hybrid_iswa_context * mctx) :
+ inp_attn(std::move(inp_attn)),
+ inp_rs(std::move(inp_rs)),
+ cparams(cparams),
+ mctx(mctx) { }
+ virtual ~llm_graph_input_mem_hybrid_iswa() = default;
+
+ void set_input(const llama_ubatch * ubatch) override;
+
+ bool can_reuse(const llm_graph_params & params) override;
+
+ std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
+ std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+ llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
+ llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+ const llama_cparams cparams;
+
+ const llama_memory_hybrid_iswa_context * mctx;
+};
+
class llm_graph_input_sampling : public llm_graph_input_i {
public:
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+ llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
+
//
// pooling
//
--- /dev/null
+#include "llama-memory-hybrid-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid_iswa
+//
+
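+// by default, attention layers are the non-recurrent layers and recurrent layers are
+// the rest; explicit layer filters can override this split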
+llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
+ const llama_model & model,
+ /* attn */
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool swa_full,
+ uint32_t kv_size,
+ uint32_t n_ubatch,
+ uint32_t n_pad,
+ /* recurrent */
+ ggml_type type_r,
+ ggml_type type_s,
+ uint32_t rs_size,
+ /* common */
+ uint32_t n_seq_max,
+ bool offload,
+ bool unified,
+ /* layer filters */
+ const layer_filter_cb & filter_attn,
+ const layer_filter_cb & filter_recr) :
+ hparams(model.hparams),
+ mem_attn(new llama_kv_cache_iswa(
+ model,
+ type_k,
+ type_v,
+ v_trans,
+ offload,
+ swa_full,
+ unified,
+ kv_size,
+ n_seq_max,
+ n_ubatch,
+ n_pad,
+ filter_attn == nullptr ?
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
+ : filter_attn,
+ nullptr
+ )),
+ mem_recr(new llama_memory_recurrent(
+ model,
+ type_r,
+ type_s,
+ offload,
+ rs_size,
+ n_seq_max,
+ filter_recr == nullptr ?
+ [&](int32_t il) { return hparams.is_recurrent(il); }
+ : filter_recr
+ )) {}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
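+    // on any failure below, fall through to a FAILED_PREPARE context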
+ do {
+ balloc.split_reset();
+
+ // follow the recurrent pattern for creating the ubatch splits
+ std::vector<llama_ubatch> ubatches;
+
+ while (true) {
+ llama_ubatch ubatch;
+
+ if (embd_all) {
+ // if all tokens are output, split by sequence
+ ubatch = balloc.split_seq(n_ubatch);
+ } else {
+ // TODO: non-sequential equal split can be done if using unified KV cache
+ // for simplicity, we always use sequential equal split for now
+ ubatch = balloc.split_equal(n_ubatch, true);
+ }
+
+ if (ubatch.n_tokens == 0) {
+ break;
+ }
+
+ ubatches.push_back(std::move(ubatch)); // NOLINT
+ }
+
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
+ // prepare the recurrent batches first
+ if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined state at this point?
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+ return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ }
+
+        // prepare the attention caches: the iswa cache needs slot infos for both the base and the SWA cache
+ auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
+ if (sinfos_base.empty()) {
+ LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
+ return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ }
+
+ auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
+ if (sinfos_swa.empty()) {
+ LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
+ return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ }
+
+ return std::make_unique<llama_memory_hybrid_iswa_context>(
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+ } while(false);
+
+ return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
+ return std::make_unique<llama_memory_hybrid_iswa_context>(this);
+}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
+ return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid_iswa::get_can_shift() const {
+    // shifting is trivially supported for the recurrent cache, so this depends only on the attention cache
+ return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid_iswa::clear(bool data) {
+ mem_attn->clear(data);
+ mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ // Try removing from the recurrent cache first since it may fail. If it does
+ // fail, the cache will not have been mutated.
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+ return false;
+ }
+ return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+ mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
+ mem_attn->seq_keep(seq_id);
+ mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+ mem_attn->seq_add(seq_id, p0, p1, shift);
+ mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+ mem_attn->seq_div(seq_id, p0, p1, d);
+ mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
+ // the min of the total cache is the max of the two caches' min values
+ return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
+ // the max of the total cache is the min of the two caches' max values
+ return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
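+    // merge the per-buffer-type breakdowns of the attention and recurrent caches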
+ std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
+ for (const auto & buft_size : mem_recr->memory_breakdown()) {
+ mb[buft_size.first] += buft_size.second;
+ }
+ return mb;
+}
+
+void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+ mem_attn->state_write(io, seq_id, flags);
+ mem_recr->state_write(io, seq_id, flags);
+}
+
+void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ mem_attn->state_read(io, seq_id, flags);
+ mem_recr->state_read(io, seq_id, flags);
+}
+
+llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
+ return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
+ return mem_recr.get();
+}
+
+//
+// llama_memory_hybrid_iswa_context
+//
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
+ ctx_attn(mem->get_mem_attn()->init_full()),
+ ctx_recr(mem->get_mem_recr()->init_full()),
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
+ llama_memory_hybrid_iswa * mem,
+ llama_context * lctx,
+ bool optimize) :
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
+ llama_memory_hybrid_iswa * mem,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
+ std::vector<llama_ubatch> ubatches) :
+ ubatches(std::move(ubatches)),
+ // note: here we copy the ubatches. not sure if this is ideal
+ ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+bool llama_memory_hybrid_iswa_context::next() {
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
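+    // advance both sub-contexts in lock-step with the shared ubatch index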
+ ctx_attn->next();
+ ctx_recr->next();
+
+ if (++i_next >= ubatches.size()) {
+ return false;
+ }
+
+ return true;
+}
+
+bool llama_memory_hybrid_iswa_context::apply() {
+ assert(!llama_memory_status_is_fail(status));
+
+ bool res = true;
+
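+    // bitwise AND (no short-circuit) so that both contexts are applied even if the first fails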
+ res = res & ctx_attn->apply();
+ res = res & ctx_recr->apply();
+
+ return res;
+}
+
+llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
+ return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+ return ubatches[i_next];
+}
+
+const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
+ return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
+}
+
+const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
+ return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
+}
--- /dev/null
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-iswa.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid_iswa
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
+// support models where each layer may be either attention-based (with SWA support) or recurrent
+
+class llama_memory_hybrid_iswa : public llama_memory_i {
+public:
+ llama_memory_hybrid_iswa(
+ const llama_model & model,
+ /* attn */
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool swa_full,
+ uint32_t kv_size,
+ uint32_t n_ubatch,
+ uint32_t n_pad,
+ /* recurrent */
+ ggml_type type_r,
+ ggml_type type_s,
+ uint32_t rs_size,
+ /* common */
+ uint32_t n_seq_max,
+ bool offload,
+ bool unified,
+ /* layer filters */
+ const layer_filter_cb & filter_attn = nullptr,
+ const layer_filter_cb & filter_recr = nullptr);
+
+ ~llama_memory_hybrid_iswa() = default;
+
+ //
+ // llama_memory_i
+ //
+
+ llama_memory_context_ptr init_batch(
+ llama_batch_allocr & balloc,
+ uint32_t n_ubatch,
+ bool embd_all) override;
+
+ llama_memory_context_ptr init_full() override;
+
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+ bool get_can_shift() const override;
+
+ void clear(bool data) override;
+
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+ void seq_keep(llama_seq_id seq_id) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+ std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
+ // state write/load
+
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+ //
+ // llama_memory_hybrid_iswa specific API
+ //
+
+ llama_kv_cache_iswa * get_mem_attn() const;
+ llama_memory_recurrent * get_mem_recr() const;
+
+private:
+ const llama_hparams & hparams;
+
+ const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
+ const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
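+// iterates over the ubatches of a split batch, keeping the attention (iswa) and
+// recurrent memory contexts in lock-step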
+class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
+public:
+ using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+
+ // init failure
+ explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
+
+ // init full
+ explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
+
+ // init update
+ explicit llama_memory_hybrid_iswa_context(
+ llama_memory_hybrid_iswa * mem,
+ llama_context * lctx,
+ bool optimize);
+
+ // init success
+ llama_memory_hybrid_iswa_context(
+ llama_memory_hybrid_iswa * mem,
+ slot_info_vec_t sinfos_base,
+ slot_info_vec_t sinfos_swa,
+ std::vector<llama_ubatch> ubatches);
+
+ ~llama_memory_hybrid_iswa_context() = default;
+
+ bool next() override;
+ bool apply() override;
+
+ llama_memory_status get_status() const override;
+ const llama_ubatch & get_ubatch() const override;
+
+ //
+ // llama_memory_hybrid_iswa_context
+ //
+
+ const llama_kv_cache_iswa_context * get_attn() const;
+ const llama_memory_recurrent_context * get_recr() const;
+
+private:
+ // the index of the next ubatch to process
+ size_t i_next = 0;
+
+ std::vector<llama_ubatch> ubatches;
+
+ const llama_memory_context_ptr ctx_attn;
+ const llama_memory_context_ptr ctx_recr;
+
+ const llama_memory_status status;
+};
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
+#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include "ggml-cpp.h"
};
}
- res = new llama_memory_hybrid(
- /* model */ *this,
- /* attn_type_k */ params.type_k,
- /* attn_type_v */ params.type_v,
- /* attn_v_trans */ !cparams.flash_attn,
- /* attn_kv_size */ cparams.n_ctx,
- /* attn_n_pad */ 1,
- /* attn_n_swa */ hparams.n_swa,
- /* attn_swa_type */ hparams.swa_type,
- /* recurrent_type_k */ GGML_TYPE_F32,
- /* recurrent_type_v */ GGML_TYPE_F32,
- /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
- /* n_seq_max */ cparams.n_seq_max,
- /* offload */ cparams.offload_kqv,
- /* unified */ cparams.kv_unified,
- /* filter_attn */ std::move(filter_attn),
- /* filter_recr */ std::move(filter_recr));
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+            // use llama_memory_hybrid_iswa for hybrid models whose attention layers use SWA
+ res = new llama_memory_hybrid_iswa(
+ /* model */ *this,
+ /* attn_type_k */ params.type_k,
+ /* attn_type_v */ params.type_v,
+ /* attn_v_trans */ !cparams.flash_attn,
+ /* attn_swa_full */ params.swa_full,
+ /* attn_kv_size */ cparams.n_ctx,
+ /* attn_n_ubatch */ cparams.n_ubatch,
+ /* attn_n_pad */ 1,
+ /* recurrent_type_r */ GGML_TYPE_F32,
+ /* recurrent_type_s */ GGML_TYPE_F32,
+ /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+ /* n_seq_max */ cparams.n_seq_max,
+ /* offload */ cparams.offload_kqv,
+ /* unified */ cparams.kv_unified,
+ /* filter_attn */ std::move(filter_attn),
+ /* filter_recr */ std::move(filter_recr));
+ } else {
+ res = new llama_memory_hybrid(
+ /* model */ *this,
+ /* attn_type_k */ params.type_k,
+ /* attn_type_v */ params.type_v,
+ /* attn_v_trans */ !cparams.flash_attn,
+ /* attn_kv_size */ cparams.n_ctx,
+ /* attn_n_pad */ 1,
+ /* attn_n_swa */ hparams.n_swa,
+ /* attn_swa_type */ hparams.swa_type,
+ /* recurrent_type_k */ GGML_TYPE_F32,
+ /* recurrent_type_v */ GGML_TYPE_F32,
+ /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+ /* n_seq_max */ cparams.n_seq_max,
+ /* offload */ cparams.offload_kqv,
+ /* unified */ cparams.kv_unified,
+ /* filter_attn */ std::move(filter_attn),
+ /* filter_recr */ std::move(filter_recr));
+ }
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;