]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: backwards pass for misc. ops, add tests (llama/11257)
authorJohannes Gäßler <redacted>
Thu, 16 Jan 2025 15:43:38 +0000 (16:43 +0100)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
* CUDA: backwards pass for misc. ops, add tests

* remove restrict from pointers

17 files changed:
ggml/include/ggml.h
ggml/src/ggml-alloc.c
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml-cuda/cross-entropy-loss.cu
ggml/src/ggml-cuda/getrows.cu
ggml/src/ggml-cuda/getrows.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/norm.cu
ggml/src/ggml-cuda/norm.cuh
ggml/src/ggml-cuda/out-prod.cu
ggml/src/ggml-cuda/rope.cu
ggml/src/ggml-cuda/softmax.cu
ggml/src/ggml-cuda/softmax.cuh
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml.c

index a9c051cd5d691586fa6a443bb6c656dcbd9b9e72..1198dc1fd93785d1f7eebb04f615e278448524e2 100644 (file)
@@ -1384,16 +1384,20 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
-    GGML_API struct ggml_tensor * ggml_soft_max_back(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // rotary position embedding
     // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
index 8dc8226ac4932c946e7ff7969e9b020564685ce1..9a3bf9f29235c60b137da9659f22d5b5f615c421 100644 (file)
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
     return true;
 }
 
+// ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_LOG:
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
             return true;
 
         default:
index 8bf5f781a599ef2f6b88eba42f495d696ef29ce7..8040e20fe9d9a28d9fdace7f41481cf55aa46a19 100644 (file)
@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float *  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
@@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param src0         grad[#00]
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
 }
 
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {
index 5c47ceb7314577abe8e8563bf4ba889329416adc..35a1c876c8631ff1bbd3219eb299aba67c16e634 100644 (file)
@@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+        case GGML_OP_SOFT_MAX_BACK: {
+            if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
+                return false;
+            }
+            float max_bias = 0.0f;
+
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+
+            return max_bias == 0.0f;
+        }
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
index ed09406a88bacb119441e457ffa6c01c60b7f5ce..27599a2b03839919fcdc87026babaa57678d64b3 100644 (file)
@@ -5,95 +5,89 @@
 #include <cmath>
 #include <cstdint>
 
-static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
-
-    const int ne_tmp = WARP_SIZE*nclasses;
-
-    extern __shared__ float tmp_all[];
-    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
-    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
-
-    // Each warp first loads ne_tmp logits/labels into shared memory:
-    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
-        const int ig = i0*nclasses + i; // ig == i global
-
-        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
-        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
-    }
+template <bool use_shared>
+static __global__ void cross_entropy_loss_f32(
+        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
+    extern __shared__ float tmp[];
 
-    // Each thread in the warp then calculates the cross entropy loss for a single row.
-    // TODO: pad in order to avoid shared memory bank conflicts.
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
 
     // Find maximum for softmax:
-    float max = -INFINITY;
-    for (int i = 0; i < nclasses; ++i) {
-        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    float max_logit = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[i];
+        max_logit = fmaxf(max_logit, val);
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
+    max_logit = warp_reduce_max(max_logit);
 
     // Calculate log(softmax(logits)) which is just logits - max:
     float sum = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        float val = tmp_logits[lane_id*nclasses + i] - max;
-        sum += expf(val);
-        tmp_logits[lane_id*nclasses + i] = val;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        sum += expf(logit_i - max_logit);
     }
+    sum = warp_reduce_sum(sum);
     sum = logf(sum);
 
     // log(exp(logits - max) / sum) = (logits - max) - log(sum)
     float loss = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        loss += (logit_i - max_logit - sum) * labels[i];
     }
     loss = -warp_reduce_sum(loss) / (float)k;
 
-    __syncthreads();
-
-    if (lane_id == 0) {
-        tmp_all[warp_id] = loss;
-    }
-
-    __syncthreads();
-
-    if (warp_id != 0) {
-        return;
-    }
-
-    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
-    loss = warp_reduce_sum(loss);
-
-    if (lane_id != 0) {
+    if (threadIdx.x != 0) {
         return;
     }
 
     dst[blockIdx.x] = loss;
 }
 
-static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+template <bool use_shared>
+static __global__ void cross_entropy_loss_back_f32(
+        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
+        float * __restrict__ dst, const int nclasses) {
     extern __shared__ float tmp[];
 
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
+    dst    += int64_t(blockIdx.x)*nclasses;
+
     float maxval = -INFINITY;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = logits[blockIdx.x*nclasses + i];
+        const float val = logits[i];
         maxval = fmaxf(maxval, val);
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
     maxval = warp_reduce_max(maxval);
 
     float sum = 0.0f;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = expf(tmp[i] - maxval);
+        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
         sum += val;
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        } else {
+            dst[i] = val;
+        }
     }
     sum = warp_reduce_sum(sum);
     const float sm_scale = 1.0f/sum;
 
-    const float d_by_nrows = *loss/gridDim.x;
+    const float d_by_nrows = *grad/gridDim.x;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+        const float val = use_shared ? tmp[i] : dst[i];
+        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
     }
 }
 
@@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t stream = ctx.stream();
 
-    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const size_t nbytes_shared = ne00*sizeof(float);
+
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
 
-    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    } else {
+        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    }
+    CUDA_CHECK(cudaGetLastError());
 
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
 
 void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * opt0 = dst->src[2];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
+    const ggml_tensor * grad  = dst->src[0];
+    const ggml_tensor * src0f = dst->src[1];
+    const ggml_tensor * src1f = dst->src[2];
+
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_scalar(grad));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
+    GGML_ASSERT(ggml_are_same_shape(src0f, dst));
 
-    const int64_t ne00  = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t ne00  = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
 
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    const float * opt0_d = (const float *) opt0->data;
-    float       * dst_d  = (float       *) dst->data;
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    const float * src1f_d = (const float *) src1f->data;
+    float       * dst_d   = (float       *) dst->data;
 
     cudaStream_t stream = ctx.stream();
 
     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int shmem = ne00*sizeof(float);
-
-    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+    const size_t nbytes_shared = ne00*sizeof(float);
+
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    } else {
+        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    }
 }
index 4c3703238cb6eb2922cc6b0682596fc6a19a7f37..4cef53a98cfd6acf87175eba687eb1ea16e89e2b 100644 (file)
@@ -3,15 +3,15 @@
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
     const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -22,10 +22,10 @@ static __global__ void k_get_rows(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+    const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
+    const int ib   =  i00/qk;      // block index
+    const int iqs  = (i00%qk)/qr;  // quant index
     const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
@@ -39,15 +39,15 @@ static __global__ void k_get_rows(
 
 template<typename src0_t, typename dst_t>
 static __global__ void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
-
-    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
+
+    const int i00 =  blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
     dst_row[i00] = src0_row[i00];
 }
 
+template<typename grad_t, typename dst_t>
+static __global__ void k_get_rows_back_float(
+        const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
+    const int col = blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    float sum = 0.0f;
+
+    for (int64_t i = 0; i < nrows_grad; ++i) {
+        if (rows[i] != dst_row) {
+            continue;
+        }
+        sum += grad[i*ncols + col];
+    }
+
+    dst[dst_row*ncols + col] = sum;
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
 
 template<typename src0_t>
-static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda_float(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    GGML_ASSERT(ne13 == 1);
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
     //const size_t s13 = nb13 / ggml_element_size(src1);
 
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
@@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
 
+    const void    * src0_d = (const void    *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
     GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
+    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float   * src0_d = (const float   *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne02*ne03 == 1);
+    GGML_ASSERT(ne12*ne13 == 1);
+    GGML_ASSERT(ne2*ne3 == 1);
+
+    const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne1, 1);
+
+    k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
+}
index bbf1302325ce4d01c5ae70f5a727eddd66954acd..a1ca643f1c5300d1e2f91f111d7f3f76be4abde3 100644 (file)
@@ -1,5 +1,8 @@
 #include "common.cuh"
 
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
 
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 9118edc7282268c6ba7abbece154eab99019241a..7fd1fc85346f1cff40065ada241db9f003fdbeae 100644 (file)
@@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
+        case GGML_OP_GET_ROWS_BACK:
+            ggml_cuda_op_get_rows_back(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_LEAKY_RELU:
             ggml_cuda_op_leaky_relu(ctx, dst);
             break;
+        case GGML_OP_SILU_BACK:
+            ggml_cuda_op_silu_back(ctx, dst);
+            break;
         case GGML_OP_RMS_NORM:
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
+        case GGML_OP_RMS_NORM_BACK:
+            ggml_cuda_op_rms_norm_back(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2138,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOFT_MAX:
             ggml_cuda_op_soft_max(ctx, dst);
             break;
+        case GGML_OP_SOFT_MAX_BACK:
+            ggml_cuda_op_soft_max_back(ctx, dst);
+            break;
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
@@ -2912,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
             } break;
         case GGML_OP_OUT_PROD:
-            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2928,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                         return false;
                 }
             } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -3001,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
                 return false;
             } break;
+        case GGML_OP_SILU_BACK:
+            return ggml_is_contiguous(op->src[0]);
+            break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
         case GGML_OP_NONE:
@@ -3027,6 +3047,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
             return true;
+        case GGML_OP_SOFT_MAX_BACK: {
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+            return max_bias == 0.0f;
+        }
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
             const size_t ts = ggml_type_size(op->src[0]->type);
index 133e219f0aeda890bd2de2026c424e882e6643e8..d991ec972813ffb583e89944898a113bc1c3548d 100644 (file)
@@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    float2 mean_var = make_float2(0.f, 0.f);
+    x   += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;
+
+    float2 mean_var = make_float2(0.0f, 0.0f);
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         mean_var.x += xi;
         mean_var.y += xi * xi;
     }
 
     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float2 s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = mean_var;
         }
@@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const float inv_std = rsqrtf(var + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }
 
@@ -40,14 +44,8 @@ template <int block_size>
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
     // threadIdx.x: block_size idx
-    int start = blockIdx.x * group_size;
-    int end = start + group_size;
-
-    start += threadIdx.x;
-
-    if (end >= ne_elements) {
-        end = ne_elements;
-    }
+    const int start =     blockIdx.x*group_size + threadIdx.x;
+    const int end   = min(blockIdx.x*group_size + group_size,  ne_elements);
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float mean = tmp / group_size;
+    const float mean = tmp / group_size;
     tmp = 0.0f;
 
     for (int j = start; j < end; j += block_size) {
-        float xi = x[j] - mean;
+        const float xi = x[j] - mean;
         dst[j] = xi;
         tmp += xi * xi;
     }
@@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float variance = tmp / group_size;
-    float scale = rsqrtf(variance + eps);
+    const float variance = tmp / group_size;
+    const float scale = rsqrtf(variance + eps);
     for (int j = start; j < end; j += block_size) {
         dst[j] *= scale;
     }
@@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    x   += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * x[row*ncols + col];
+        dst[col] = scale * x[col];
+    }
+}
+
+template <int block_size>
+static __global__ void rms_norm_back_f32(
+        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    grad += int64_t(row)*ncols;
+    xf   += int64_t(row)*ncols;
+    dst  += int64_t(row)*ncols;
+
+    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
+    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xfi = xf[col];
+        sum_xx += xfi * xfi;
+        sum_xg += xfi * grad[col];
+    }
+
+    // sum up partial sums
+    sum_xx = warp_reduce_sum(sum_xx);
+    sum_xg = warp_reduce_sum(sum_xg);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum_xx[32];
+        __shared__ float s_sum_xg[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum_xx[warp_id] = sum_xx;
+            s_sum_xg[warp_id] = sum_xg;
+        }
+        __syncthreads();
+
+        sum_xx = s_sum_xx[lane_id];
+        sum_xx = warp_reduce_sum(sum_xx);
+
+        sum_xg = s_sum_xg[lane_id];
+        sum_xg = warp_reduce_sum(sum_xg);
+    }
+
+    const float mean_eps = sum_xx / ncols + eps;
+    const float sum_eps  = sum_xx + ncols*eps;
+
+    const float scale_grad = rsqrtf(mean_eps);
+    const float scale_x    = -scale_grad * sum_xg/sum_eps;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
     }
 }
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+static void group_norm_f32_cuda(
+        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
+static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
     norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
@@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     float eps;
     memcpy(&eps, dst->op_params + 1, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
     group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
     rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * grad  = dst->src[0]; // gradients
+    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
+
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    float       * dst_d   = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(grad));
+
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
+}
index 431a8f74d55c75e9ba3b691ae3cd7007610ae85f..d63d34380b0a709eda1c73bdf4a51f87e06f77e6 100644 (file)
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 619cfdcb5894ad79d977c0f4fee3582558bf0b4a..73e3e2c47f28a8801c38b730748432618f926f94 100644 (file)
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_ASSERT(ne01 == ne11);
     GGML_ASSERT(ne0 == ne00);
     GGML_ASSERT(ne1 == ne10);
 
-    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 % src0->ne[2] == 0);
+    GGML_ASSERT(ne3 % src0->ne[3] == 0);
+
     GGML_ASSERT(ne2 == src1->ne[2]);
-    GGML_ASSERT(ne3 == src0->ne[3]);
     GGML_ASSERT(ne3 == src1->ne[3]);
 
     const float * src0_d = (const float *) src0->data;
@@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float alpha = 1.0f;
     const float beta = 0.0f;
 
-    GGML_ASSERT(ne2 == 1);
-    GGML_ASSERT(ne3 == 1);
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
     const bool src1_T = ggml_is_transposed(src1);
@@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
     GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
 
-    CUBLAS_CHECK(
-        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
-                ne0, ne1, ne01,
-                &alpha, src0_d, ne00,
-                        src1_d, ldb,
-                &beta,  dst_d,  ne0));
+    // data strides in dimensions 2/3
+    const size_t s02 = nb02 / sizeof(float);
+    const size_t s03 = nb03 / sizeof(float);
+    const size_t s12 = nb12 / sizeof(float);
+    const size_t s13 = nb13 / sizeof(float);
+    const size_t s2  = nb2  / sizeof(float);
+    const size_t s3  = nb3  / sizeof(float);
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    // TODO batched matrix multiplication
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            CUBLAS_CHECK(
+                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                        ne0, ne1, ne01,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+                                src1_d +  i3      *s13 +  i2      *s12, ldb,
+                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ne0));
+        }
+    }
 }
index e1912fee1f9abd35657f2b1530939903be34447d..18f691b2d3103fce13cd5cd60f522798157d91e6 100644 (file)
@@ -39,9 +39,9 @@ static __device__ void rope_yarn(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_norm(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -83,9 +83,9 @@ static __global__ void rope_norm(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_neox(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -127,9 +127,9 @@ static __global__ void rope_neox(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
-        const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -187,9 +187,9 @@ static __global__ void rope_multi(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_vision(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-        const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -234,9 +234,9 @@ static __global__ void rope_vision(
 
 template<bool forward, typename T>
 static void rope_norm_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -257,9 +257,9 @@ static void rope_norm_cuda(
 
 template<bool forward, typename T>
 static void rope_neox_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -280,9 +280,9 @@ static void rope_neox_cuda(
 
 template<bool forward, typename T>
 static void rope_multi_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -303,9 +303,9 @@ static void rope_multi_cuda(
 
 template<bool forward, typename T>
 static void rope_vision_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
index c24abae1f138c6fd6d5999f5f1d09139fd0b4eba..9aa4b84893855aedfe1caa2756584ff4f35844eb 100644 (file)
@@ -1,5 +1,7 @@
 #include "common.cuh"
+#include "ggml.h"
 #include "softmax.cuh"
+#include <cstdint>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <bool use_shared, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(
+        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
+        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
+    x    += int64_t(rowx)*ncols;
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    dst  += int64_t(rowx)*ncols;
+
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
@@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
     float max_val = -INFINITY;
 
@@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             break;
         }
 
-        const int64_t ix = (int64_t)rowx*ncols + col;
-        const int64_t iy = (int64_t)rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             return;
         }
 
-        const int64_t idst = (int64_t)rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
+        dst[col] = vals[col] * inv_sum;
+    }
+}
+
+static __global__ void soft_max_back_f32(
+        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst  += int64_t(rowx)*ncols;
+
+    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dgf_dot += dstf[col]*grad[col];
+    }
+
+    dgf_dot = warp_reduce_sum(dgf_dot);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
     }
 }
 
@@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
     const uint32_t n_head      = nrows_x/nrows_y;
@@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,   32,   32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,   64,   64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,  128,  128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,  256,  256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,  512,  512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true,    0,    0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
+static void soft_max_back_f32_cuda(
+        const float * grad, const float * dstf, float * dst,
+        const int ncols, const int nrows, const float scale, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows,     1, 1);
+
+    soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
+}
+
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    const float * src0_d = (const float *)src0->data;
-    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+    const float * src0_d = (const float *) src0->data;
+    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    float       *  dst_d = (float *) dst->data;
 
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     if (use_f16) {
-        const half * src1_dd = (const half *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     } else {
-        const float * src1_dd = (const float *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     }
 }
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // grad
+    const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
+    soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
index 4ef4ff86c9c8df1165f83d3642a2e7986dc06eef..93dfee835f6ff05849a7865c5fc8f3efd87f4e7f 100644 (file)
@@ -3,3 +3,5 @@
 #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 81fc92202f25ae45a5b21669e1cb8835e33e787d..6b21f407d80490086d08aa50c68299e4fd13316a 100644 (file)
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void silu_back_f32(
+        const float * grad, const float * xf, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float xfi = xf[i];
+    const float s = 1.0f / (1.0f + expf(-xfi));
+    dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+}
+
 static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
+}
+
 static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
     tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // input from forward pass
+    const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index c91936728bab16645732822b8a5839b4667f6b76..e7f62643a2afd65441b014732690761f3c6f9eeb 100644 (file)
@@ -4,6 +4,7 @@
 #define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_SILU_BACK_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SIGMOID_BLOCK_SIZE 256
@@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 0a0dcc7e3b1da600a6610dff15ae0171882d14d1..d83b1b8ad6be88e000a4f8e9746bfb0dbdbaccb7 100644 (file)
@@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
-// ggml_soft_max_back
+// ggml_soft_max_ext_back
 
-static struct ggml_tensor * ggml_soft_max_back_impl(
+static struct ggml_tensor * ggml_soft_max_ext_back_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias,
         bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
     result->src[0] = a;
     result->src[1] = b;
 
+    memcpy((float *) result->op_params + 0, &scale,    sizeof(float));
+    memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
+
     return result;
 }
 
-struct ggml_tensor * ggml_soft_max_back(
+struct ggml_tensor * ggml_soft_max_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, false);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
 }
 
-struct ggml_tensor * ggml_soft_max_back_inplace(
+struct ggml_tensor * ggml_soft_max_ext_back_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, true);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
 }
 
 // ggml_rope
@@ -5080,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         struct ggml_tensor  * c) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-    GGML_ASSERT(ggml_is_scalar(c));
+    GGML_ASSERT(ggml_is_scalar(a));
+    GGML_ASSERT(ggml_are_same_shape(b, c));
 
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
 
     result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
     result->src[0] = a;
@@ -5262,7 +5271,7 @@ static void ggml_sub_or_set(
 }
 
 static void ggml_compute_backward(
-        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
     struct ggml_tensor * tensor = cgraph->nodes[i];
     struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
 
@@ -5406,7 +5415,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float eps;
                 memcpy(&eps, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
             }
         } break;
         case GGML_OP_MUL_MAT: {
@@ -5589,7 +5598,13 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SOFT_MAX: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+                float scale    = 1.0f;
+                float max_bias = 0.0f;
+
+                memcpy(&scale,    (const float *) tensor->op_params + 0, sizeof(float));
+                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
+
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
             }
             GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
         } break;
@@ -5630,7 +5645,7 @@ static void ggml_compute_backward(
                 const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
                 const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
 
-                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
             }
         } break;
         case GGML_OP_POOL_2D: {
@@ -5673,7 +5688,7 @@ static void ggml_compute_backward(
                 } break;
                 case GGML_UNARY_OP_SILU: {
                     if (src0_needs_grads) {
-                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                     }
                 } break;
                 case GGML_UNARY_OP_EXP: {
@@ -5690,7 +5705,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_CROSS_ENTROPY_LOSS: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;