void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask || self_kq_mask_swa) {
- // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
- if (cparams.causal_attn) {
- const int64_t n_kv = kv_self->n;
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
-
- float * data = nullptr;
- float * data_swa = nullptr;
-
- if (self_kq_mask) {
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
- data = (float *) self_kq_mask->data;
- }
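+ // ubatch layout: n_seqs sequences of n_seq_tokens tokens each, so n_tokens = n_seqs*n_seq_tokens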
+ const int64_t n_kv = kv_self->n;
+ const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+ const int64_t n_seqs = ubatch->n_seqs;
- if (self_kq_mask_swa) {
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
- data_swa = (float *) self_kq_mask_swa->data;
- }
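+ // host-side destinations for the regular mask and the sliding-window (SWA) mask; either tensor may be absent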
+ float * data = nullptr;
+ float * data_swa = nullptr;
- // For causal attention, use only the previous KV cells
- // of the correct sequence for each token of the ubatch.
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
- for (int h = 0; h < 1; ++h) {
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ if (self_kq_mask) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+ data = (float *) self_kq_mask->data;
+ }
- for (int j = 0; j < n_seq_tokens; ++j) {
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+ if (self_kq_mask_swa) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+ data_swa = (float *) self_kq_mask_swa->data;
+ }
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
- f = -INFINITY;
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+ // Example with a cache size of 10 tokens, 2 tokens already in the cache and 3 tokens in the batch:
+ // Causal mask:
+ // xxx-------
+ // xxxx------
+ // xxxxx-----
+ // Non-causal mask:
+ // xxxxx-----
+ // xxxxx-----
+ // xxxxx-----
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
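+ // Only one mask plane is built (the h loop runs once); all attention heads share the same mask.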
+ for (int h = 0; h < 1; ++h) {
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+
+ for (int i = 0; i < n_kv; ++i) {
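+ // decide whether ubatch token (s, j) may attend to KV cell i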
+ float f;
+ // mask the token if:
+ if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+ || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal attention, don't attend to future tokens
+ ) {
+ f = -INFINITY;
+ } else {
+ if (hparams.use_alibi) {
+ f = -std::abs(kv_self->cells[i].pos - pos);
} else {
- if (hparams.use_alibi) {
- f = -std::abs(kv_self->cells[i].pos - pos);
- } else {
- f = 0.0f;
- }
- }
-
- if (data) {
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+ f = 0.0f;
}
+ }
- // may need to cut off old tokens for sliding window
- if (data_swa) {
- if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
- f = -INFINITY;
- }
- data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
- }
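+ // the masks are stored row-major: one row of n_kv entries per ubatch token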
+ if (data) {
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
- }
- }
- if (data) {
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+ // may need to cut off old tokens for sliding window
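+ // a cell is outside the window once it trails the current position by n_swa or more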
+ if (data_swa) {
+ if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+ f = -INFINITY;
+ }
+ data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
+ }
- if (data_swa) {
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
- }
+ // mask padded tokens
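+ // (rows between n_tokens and GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) exist only for backend alignment)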
+ if (data) {
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
- } else {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
- // when using kv cache, the mask needs to match the kv cache size
- const int64_t n_stride = n_tokens;
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
- float * data = (float *) self_kq_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
- if (ubatch->seq_id[s0][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
- }
- }
-
- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
- }
+ // mask padded tokens
+ if (data_swa) {
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}