From: Johannes Gäßler Date: Thu, 3 Jul 2025 15:05:18 +0000 (+0200) Subject: ggml: backward pass for split swiglu (llama/14483) X-Git-Tag: upstream/1.8.0~436 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=4aaf8114e7953715ef9880463d0de46dd294e13d;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml: backward pass for split swiglu (llama/14483) --- diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b89a68db..68768842 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6042,13 +6042,28 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_GLU: { + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_SWIGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split swiglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0)); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); + } + } break; + default: { + GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); + } //break; + } + } break; case GGML_OP_NONE: { // noop } break; case GGML_OP_COUNT: default: { - fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); - GGML_ABORT("fatal error"); + GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); } //break; }