From: Sigbjørn Skjæret Date: Mon, 18 Aug 2025 17:30:17 +0000 (+0200) Subject: llama : merge conts and reshapes and remove unnecessary cont (#15380) X-Git-Tag: upstream/0.0.6199~4 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=baa9255a45105d2d3b4ec432af13b7a6eda3ff35;p=pkg%2Fggml%2Fsources%2Fllama.cpp llama : merge conts and reshapes and remove unnecessary cont (#15380) * remove unnecessary conts and merge reshapes * restore necessary conts * merge more conts and reshapes * merge even more conts and reshapes --- diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 23a26f0c..0c149cd4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6743,9 +6743,9 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -7023,9 +7023,9 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7145,13 +7145,13 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7367,13 +7367,15 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } if (model.layers[il].attn_q_norm) { @@ -7381,6 +7383,10 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + } else { + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { @@ -7388,11 +7394,11 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } else { + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } // RoPE if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { @@ -7537,9 +7543,9 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // RoPE Qcur = ggml_rope_ext( @@ -7646,13 +7652,13 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7770,7 +7776,7 @@ struct llm_build_mpt : public llm_graph_context { ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7789,17 +7795,18 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_k_norm_b, LLM_NORM, il); cb(Kcur, "Kcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { - Qcur = ggml_cont(ctx0, Qcur); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - Kcur = ggml_cont(ctx0, Kcur); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -8051,9 +8058,9 @@ struct llm_build_qwen : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -9026,21 +9033,21 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9164,21 +9171,21 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9428,17 +9435,17 @@ struct llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, @@ -9534,9 +9541,9 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -10864,8 +10871,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 cb(all_coefs, "all_coefs", il); - all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] - all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] + all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] @@ -12278,9 +12285,9 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13413,17 +13420,17 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, @@ -13526,6 +13533,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13535,11 +13543,10 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13660,6 +13667,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13669,11 +13677,10 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -16840,13 +16847,13 @@ private: ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -16913,15 +16920,13 @@ private: cb(zx, "mamba_in_proj", il); // {8192, 5, 1, 1} -> {8192, 1, 5, 1} zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); - zx = ggml_cont(ctx0, zx); - zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); cb(zx, "mamba_in_proj_out", il); // split into z and x // => {head_dim * n_heads, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); - x = ggml_cont(ctx0, x); - x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); // x = ggml_permute(ctx0, x, 0, 2, 1, 3); cb(x, "mamba_x_split", il); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index c27f8ebb..099e2823 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -508,8 +508,8 @@ struct clip_graph { const int patches_per_image = n_patches_x; const int kernel_size = hparams.proj_scale_factor; - cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); - cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); // doing a pool2d to reduce the number of output tokens cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); @@ -537,13 +537,13 @@ struct clip_graph { GGML_ASSERT(scale_factor != 0); cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, seq / (scale_factor * scale_factor), bsz); @@ -570,13 +570,13 @@ struct clip_graph { // unshuffle h cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // unshuffle w - cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); // projection cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm @@ -715,15 +715,15 @@ struct clip_graph { auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_add(ctx0, inp, inp_1); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_reshape_4d( + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); inp = ggml_reshape_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); - inp = ggml_reshape_3d( + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); } @@ -988,14 +988,14 @@ struct clip_graph { GGML_ASSERT(scale_factor > 0); cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); } @@ -1081,14 +1081,14 @@ struct clip_graph { n_patches_y, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches_x / scale_factor, n_patches_y / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches / scale_factor / scale_factor); cb(cur, "pixel_shuffle", -1); @@ -1321,8 +1321,8 @@ struct clip_graph { ggml_tensor * block_1 = nullptr; { // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] - mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3)); - mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); + mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3); + mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); // stride = 1, padding = 1, bias is nullptr block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); @@ -1427,9 +1427,9 @@ struct clip_graph { mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); // mlp_2 ne = [2048, 576, 1, 1] // // AVG Pool Layer 2*2, strides = 2 - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3); // mlp_2 ne = [576, 2048, 1, 1] - mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); // mlp_2 ne [24, 24, 2048, 1] mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); // weight ne = [3, 3, 2048, 1] @@ -1449,8 +1449,8 @@ struct clip_graph { // glm projector else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); - embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_permute(ctx0,embeddings,1,0,2,3); + embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); @@ -2005,7 +2005,6 @@ private: ggml_row_size(cur->type, n_dim), ggml_row_size(cur->type, n_dim*n_head), n_dim/2 * ggml_element_size(cur)); - second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors second = ggml_rope_ext( ctx0, second,