common_init();
- int is_pp_shared = params.is_pp_shared;
+ int is_pp_shared = params.is_pp_shared;
+ int is_tg_separate = params.is_tg_separate;
std::vector<int> n_pp = params.n_pp;
std::vector<int> n_tg = params.n_tg;
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
- for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
- const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+ const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
n_tokens,
if (!params.batched_bench_output_jsonl) {
LOG("\n");
- LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+ LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG("\n");
LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
const auto t_tg_start = ggml_time_us();
- for (int i = 0; i < tg; ++i) {
- common_batch_clear(batch);
-
+ if (is_tg_separate) {
+ // decode pattern:
+ // 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
for (int j = 0; j < pl; ++j) {
- common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+ for (int i = 0; i < tg; ++i) {
+ common_batch_clear(batch);
+
+ common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+
+ if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
+ return 1;
+ }
+ }
}
+ } else {
+ // decode pattern:
+ // 0123 0123 0123 ...
+ for (int i = 0; i < tg; ++i) {
+ common_batch_clear(batch);
- if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
- LOG_ERR("%s: llama_decode() failed\n", __func__);
- return 1;
+ for (int j = 0; j < pl; ++j) {
+ common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+ }
+
+ if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
+ return 1;
+ }
}
}