}
}
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
- const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
- (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
- (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+ const char * swa_type_str = "unknown";
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+ };
+
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
- GGML_ASSERT(kq_mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
- float * data = (float *) kq_mask->data;
-
- // [TAG_NO_CACHE_ISWA]
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const llama_pos p1 = ubatch->pos[i1];
- for (int h = 0; h < 1; ++h) {
- for (int i1 = 0; i1 < n_tokens; ++i1) {
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
- for (int i0 = 0; i0 < n_tokens; ++i0) {
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ const llama_pos p0 = ubatch->pos[i0];
+ // mask different sequences
if (s0 != s1) {
- continue; // skip different sequences
+ continue;
}
- if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
- continue; // skip future tokens for causal attention
+ // mask future tokens
+ if (cparams.causal_attn && p0 > p1) {
+ continue;
}
- // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
- //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
- // continue; // skip masked tokens for SWA
- //}
-
- // TODO: reimplement this like in llama_kv_cache_unified
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
+ // apply SWA if any
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+ continue;
}
+
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
+ };
+
+ {
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+ float * data = (float *) self_kq_mask->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+ }
}
- if (debug) {
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+ float * data = (float *) self_kq_mask_swa->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
}
}
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
- const auto n_kv = k->ne[1];
-
ggml_tensor * cur;
// TODO: replace hardcoded padding with ggml-provided padding
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+ if (cparams.flash_attn && kq_b == nullptr) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
if (v_trans) {
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
- inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
- ggml_set_input(inp->kq_mask);
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask_swa);
+
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+ } else {
+ inp->self_kq_mask_swa = nullptr;
+ inp->self_kq_mask_swa_cnv = nullptr;
+ }
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
- const auto & kq_mask = inp->get_kq_mask();
+ const bool is_swa = hparams.is_swa(il);
+
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams