LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+ LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+ { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
uint32_t n_head_kv;
uint32_t n_layer;
uint32_t n_rot;
+ uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_ff;
if (this->n_head_kv != other.n_head_kv) return true;
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
+ if (this->n_swa != other.n_swa) return true;
if (this->n_embd_head_k != other.n_embd_head_k) return true;
if (this->n_embd_head_v != other.n_embd_head_v) return true;
if (this->n_ff != other.n_ff) return true;
void * abort_callback_data = nullptr;
// input tensors
- struct ggml_tensor * inp_tokens; // I32 [n_batch]
- struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
- struct ggml_tensor * inp_pos; // I32 [n_batch]
- struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
- struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
- struct ggml_tensor * inp_K_shift; // I32 [kv_size]
- struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
- struct ggml_tensor * inp_cls; // I32 [n_batch]
- struct ggml_tensor * inp_s_copy; // I32 [kv_size]
- struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
- struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
+ struct ggml_tensor * inp_tokens; // I32 [n_batch]
+ struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
+ struct ggml_tensor * inp_pos; // I32 [n_batch]
+ struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+ struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
+ struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
+ struct ggml_tensor * inp_cls; // I32 [n_batch]
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+ struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+ struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
// control vectors
struct llama_control_vector cvec;
} break;
case LLM_ARCH_GEMMA2:
{
+ hparams.n_swa = 4096; // default value of gemma 2
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
+ LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
ctx0 = ggml_init(params);
- lctx.inp_tokens = nullptr;
- lctx.inp_embd = nullptr;
- lctx.inp_pos = nullptr;
- lctx.inp_out_ids = nullptr;
- lctx.inp_KQ_mask = nullptr;
- lctx.inp_K_shift = nullptr;
- lctx.inp_mean = nullptr;
- lctx.inp_cls = nullptr;
- lctx.inp_s_copy = nullptr;
- lctx.inp_s_mask = nullptr;
- lctx.inp_s_seq = nullptr;
+ lctx.inp_tokens = nullptr;
+ lctx.inp_embd = nullptr;
+ lctx.inp_pos = nullptr;
+ lctx.inp_out_ids = nullptr;
+ lctx.inp_KQ_mask = nullptr;
+ lctx.inp_KQ_mask_swa = nullptr;
+ lctx.inp_K_shift = nullptr;
+ lctx.inp_mean = nullptr;
+ lctx.inp_cls = nullptr;
+ lctx.inp_s_copy = nullptr;
+ lctx.inp_s_mask = nullptr;
+ lctx.inp_s_seq = nullptr;
}
void free() {
cb(lctx.inp_K_shift, "K_shift", -1);
ggml_set_input(lctx.inp_K_shift);
-
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * tmp =
}
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
- if (causal) {
- lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- } else {
- lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- }
+ lctx.inp_KQ_mask = causal
+ ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+ : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
+
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
}
+ struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+ GGML_ASSERT(hparams.n_swa > 0);
+
+ lctx.inp_KQ_mask_swa = causal
+ ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+ : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+ ggml_set_input(lctx.inp_KQ_mask_swa);
+
+ return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+ }
+
struct ggml_tensor * build_inp_mean() {
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
cb(lctx.inp_mean, "inp_mean", -1);
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ // gemma 2 requires different mask for layers using sliding window (SWA)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+ struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
for (int il = 0; il < n_layer; ++il) {
+ // (il % 2) layers use SWA
+ struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
cur = llm_build_norm(ctx0, cur, hparams,
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
- float * data = (float *) lctx.inp_KQ_mask->data;
+ float * data = (float *) lctx.inp_KQ_mask->data;
+ float * data_swa = nullptr;
+
+ if (lctx.inp_KQ_mask_swa) {
+ data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+ }
// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the batch.
}
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+ // may need to cut off old tokens for sliding window
+ if (data_swa) {
+ if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+ f = -INFINITY;
+ }
+ data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+ }
}
}