llama : fix parallel processing for lfm2 (#14705)
author    Tarek Dakhran <redacted>
Thu, 17 Jul 2025 07:22:11 +0000 (09:22 +0200)
committer GitHub <redacted>
Thu, 17 Jul 2025 07:22:11 +0000 (09:22 +0200)
src/llama-model.cpp

index 9d8a686e0a571a122313e82461c1105178273a72..cdf1e424294e58bfa6a773e136f57cf285feb1aa 100644 (file)
@@ -16554,7 +16554,19 @@ struct llm_build_lfm2 : public llm_graph_context {
                                         ggml_tensor        * cur,
                                         llm_graph_input_rs * inp_recr,
                                         int                il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto *   mctx_cur     = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head      = mctx_cur->get_head();
+        const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t  n_seqs       = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);
@@ -16562,38 +16574,48 @@ struct llm_build_lfm2 : public llm_graph_context {
         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
 
         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs  = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs    = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
 
-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        // the last d_conv columns are the new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
 
-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write the new conv state
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    new_conv,
+                    ggml_view_1d(
+                        ctx0,
+                        conv_state,
+                        ggml_nelements(new_conv),
+                        kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+                        )
+                    )
+                );
 
         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);
 
         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
 
         return y;
     }
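
For illustration only (not part of the commit): the write-back above no longer overwrites the whole recurrent-state tensor. It copies the new {d_conv, n_embd, n_seqs} state into a 1-D view of conv_state that starts at the ubatch's cache slot kv_head, which is what keeps parallel sequences from clobbering each other's slots. Below is a minimal standalone C++ sketch of that offset arithmetic, assuming toy dimensions and an f32 element size; all concrete numbers are made up for the example and are not taken from the model.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Toy dimensions, for illustration only.
        const int64_t n_embd = 8;   // embedding width
        const int64_t d_conv = 2;   // hparams.n_shortconv_l_cache - 1
        const int64_t n_seqs = 2;   // sequences in the ubatch
        const size_t  esize  = 4;   // element size in bytes, assuming an f32 state

        // Each cache slot holds one {d_conv, n_embd} conv state, so a ubatch whose
        // first slot is kv_head writes at this byte offset into conv_state,
        // mirroring the ggml_view_1d() offset in the diff above.
        const int64_t kv_head = 3;
        const size_t  offset  = (size_t) kv_head * d_conv * n_embd * esize;

        // Number of elements copied, i.e. ggml_nelements(new_conv) for a
        // {d_conv, n_embd, n_seqs} tensor.
        const int64_t n_copied = d_conv * n_embd * n_seqs;

        printf("byte offset = %zu, elements copied = %lld\n", offset, (long long) n_copied);
        return 0;
    }

Compared to the removed ggml_cpy(new_conv, conv_state), which always wrote from offset 0, writing through this bounded view confines each ubatch to its own region of the shared recurrent-state buffer, which is the parallel-processing fix the commit title refers to.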