float * data = (float *) kq_mask->data;
+ // [TAG_NO_CACHE_ISWA]
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
continue; // skip future tokens for causal attention
}
- if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
- continue; // skip masked tokens for SWA
- }
+ // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+ //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+ // continue; // skip masked tokens for SWA
+ //}
// TODO: reimplement this like in llama_kv_cache_unified
if (hparams.use_alibi) {
return res;
}
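For orientation, the hunk above elides most of the mask-filling loop; conceptually it writes 0.0f where a token pair may attend and -INFINITY where it may not (different sequence, future position, and eventually the SWA window). A minimal self-contained sketch of that fill, using toy types rather than the actual llama.cpp structures:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // toy stand-ins for the ubatch fields used above; purely illustrative
    struct toy_batch {
        std::vector<int> seq_id; // one sequence id per token
        std::vector<int> pos;    // one position per token
    };

    // fill an n_tokens x n_tokens mask: 0.0f where token i1 may attend to i0,
    // -INFINITY otherwise (different sequence or future position)
    static void fill_causal_mask(const toy_batch & ub, std::vector<float> & mask) {
        const int n = (int) ub.pos.size();
        mask.assign((size_t) n * n, -INFINITY);
        for (int i1 = 0; i1 < n; ++i1) {
            for (int i0 = 0; i0 < n; ++i0) {
                if (ub.seq_id[i0] != ub.seq_id[i1]) continue; // different sequence
                if (ub.pos[i0] > ub.pos[i1])        continue; // future token (causal)
                // a per-layer SWA check would go here once the cacheless iSWA path is implemented
                mask[(size_t) i1 * n + i0] = 0.0f;
            }
        }
    }

    int main() {
        toy_batch ub = { {0, 0, 0}, {0, 1, 2} };
        std::vector<float> mask;
        fill_causal_mask(ub, mask);
        printf("mask[0][1] = %f (future -> masked)\n", mask[1]);  // -inf
        printf("mask[1][0] = %f (past   -> visible)\n", mask[3]); // 0
        return 0;
    }

The commented-out is_masked_swa call above is exactly the spot where the per-layer SWA/non-SWA distinction would have to be plugged in for the cacheless iSWA path.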
-bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
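The body of the switch is not shown in this hunk. As a rough, self-contained sketch (toy enum and names, not the actual llama.cpp definitions): a standard sliding window masks any key whose position trails the query by n_swa or more, while a "none" type never masks on distance; the chunked and symmetric variants are omitted here.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // illustrative stand-in for llama_swa_type; values are hypothetical
    enum toy_swa_type { TOY_SWA_NONE, TOY_SWA_STANDARD };

    // returns true when key position p0 must be masked for query position p1
    static bool toy_is_masked_swa(uint32_t n_swa, toy_swa_type swa_type, int64_t p0, int64_t p1) {
        assert(p0 >= 0 && p1 >= 0);
        switch (swa_type) {
            case TOY_SWA_NONE:
                return false;                       // no window -> nothing extra is masked
            case TOY_SWA_STANDARD:
                return p1 - p0 >= (int64_t) n_swa;  // key falls outside the sliding window
        }
        return false;
    }

    int main() {
        printf("%d\n", toy_is_masked_swa(4, TOY_SWA_STANDARD, 0, 3)); // 0: key still inside the window
        printf("%d\n", toy_is_masked_swa(4, TOY_SWA_STANDARD, 0, 4)); // 1: key outside the window -> masked
        printf("%d\n", toy_is_masked_swa(4, TOY_SWA_NONE,     0, 4)); // 0: never masked by distance
        return 0;
    }

This asymmetry is what the iSWA cache construction further down relies on: the base cache is created with LLAMA_SWA_TYPE_NONE and so never masks by distance, while the SWA cache is created with hparams.n_swa and hparams.swa_type.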
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
- bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+ // note that this function uses different SWA parameters from those in the hparams
+ // TODO: think of a better place for this function
+ // TODO: pack the SWA params in a struct?
+ static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
kv_base = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
- 0, filter_base, reuse);
+ 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
- hparams.n_swa, filter_swa, reuse);
+ hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
void llama_kv_cache_iswa::clear(bool data) {
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
+ llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
model(model), hparams(model.hparams), v_trans(v_trans),
- n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % n_pad == 0);
}
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
- return hparams.is_masked_swa(p0, p1);
+ return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
+ llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
+ // this is the SWA type of the cache - not to be confused with the model SWA type
+ const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
+ llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
n_seq_max,
n_pad,
n_swa,
+ swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
+ llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
- auto * inp_attn = build_attn_inp_no_cache();
+ // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+ auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
- case LLM_ARCH_GEMMA_EMBEDDING:
+ //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
{
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_swa */ hparams.n_swa,
+ /* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max,
padding,
hparams.n_swa,
+ hparams.swa_type,
nullptr,
nullptr);
}