]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : fix geglu (#14077)
authorGeorgi Gerganov <redacted>
Mon, 9 Jun 2025 14:17:31 +0000 (17:17 +0300)
committerGitHub <redacted>
Mon, 9 Jun 2025 14:17:31 +0000 (17:17 +0300)
ggml-ci

src/llama-graph.cpp

index 55390d42e72caa0e70ec779b0040aa99cfaf1df2..27c9ab74be1125e6b7811c75a0a5ae92bf6be3a0 100644 (file)
@@ -663,22 +663,14 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Split into two equal parts
                 int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
-                                                ctx0, cur, split_point,
-                                                cur->ne[1], cur->nb[1], 0
-                                            ));
-                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
-                                                ctx0, cur, split_point,
-                                                cur->ne[1], cur->nb[1],
-                                                split_point * ggml_element_size(cur)
-                                            ));
-
-                // Apply GELU activation function to the first part
-                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
-                cb(output_ffn_up, "ffn_gelu", il);
-
-                // Element-wise multiplication between the activated part and the gate part
-                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                // TODO: these conts should not be needed
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_geglu", il);
             } break;
     }