]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml: backward pass for split swiglu (llama/14483)
authorJohannes Gäßler <redacted>
Thu, 3 Jul 2025 15:05:18 +0000 (17:05 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 16:23:56 +0000 (19:23 +0300)
ggml/src/ggml.c

index b89a68db83c69b447bb82b344d1c48eaa88df1f3..68768842904602657ffe8d8158805d54b8f35a71 100644 (file)
@@ -6042,13 +6042,28 @@ static void ggml_compute_backward(
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case GGML_OP_GLU: {
+            switch (ggml_get_glu_op(tensor)) {
+                case GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case GGML_OP_NONE: {
             // noop
         } break;
         case GGML_OP_COUNT:
         default: {
-            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
-            GGML_ABORT("fatal error");
+            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
         } //break;
     }