]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: backward pass for split swiglu (#14483)
authorJohannes Gäßler <redacted>
Thu, 3 Jul 2025 15:05:18 +0000 (17:05 +0200)
committerGitHub <redacted>
Thu, 3 Jul 2025 15:05:18 +0000 (17:05 +0200)
ggml/src/ggml.c
tests/test-backend-ops.cpp

index fdb57e178588498ffe20e3b6c9d062d03cc02aa3..b481193da3d8cd2ee837e970ef160583d6a792f1 100644 (file)
@@ -6050,13 +6050,28 @@ static void ggml_compute_backward(
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case GGML_OP_GLU: {
+            switch (ggml_get_glu_op(tensor)) {
+                case GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case GGML_OP_NONE: {
             // noop
         } break;
         case GGML_OP_COUNT:
         default: {
-            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
-            GGML_ABORT("fatal error");
+            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
         } //break;
     }
 
index 2ab6dd06fc88e4c862d98af3422590a0ccf4e121..c76635793f3822a0cf31afe5e85a274bbf53a2c7 100644 (file)
@@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
             ggml_set_name(a, "a");
 
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
             ggml_set_name(a, "view_of_a");
 
             b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
             ggml_set_name(b, "b");
 
             b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
             ggml_set_name(a, "view_of_b");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
             ggml_set_name(a, "a");
 
             b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
             ggml_set_name(b, "b");
         }