GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
}
- res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
- res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+ res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+ res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
res.strm[s] = seq_to_stream[seq_id];
res.idxs[s].reserve(n_tokens);
return result;
}
-uint32_t llama_kv_cache::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
- for (uint32_t s = 0; s < n_stream; ++s) {
- const auto & cells = v_cells[s];
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+ const auto & cells = v_cells[sinfo.strm[s]];
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
}
// note: v->nb[1] <= v->nb[2]
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
- ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
- ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+ ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
}
// note: v->nb[1] > v->nb[2]
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
- ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
- ggml_row_size(v->type, kv_size), // v->nb[2]
- ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+ ggml_row_size(v->type, kv_size), // v->nb[2]
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
}
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
- n_kv = kv->get_n_kv();
+ n_kv = kv->get_n_kv(sinfos[i_cur]);
return true;
}