float scale,
float max_bias);
- GGML_API struct ggml_tensor * ggml_soft_max_back(
+ GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
- struct ggml_tensor * b);
+ struct ggml_tensor * b,
+ float scale,
+ float max_bias);
// in-place, returns view(a)
- GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+ GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
- struct ggml_tensor * b);
+ struct ggml_tensor * b,
+ float scale,
+ float max_bias);
// rotary position embedding
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
return true;
}
+// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
case GGML_OP_LOG:
case GGML_OP_UNARY:
case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
return true;
default:
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * grad = dst->src[1];
+ const struct ggml_tensor * grad = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
assert(ggml_is_contiguous_1(grad));
- assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(src1));
assert(ggml_is_contiguous_1(dst));
- assert(ggml_are_same_shape(src0, dst));
- assert(ggml_are_same_shape(src0, grad));
+ assert(ggml_are_same_shape(src1, dst));
+ assert(ggml_are_same_shape(src1, grad));
const int ith = params->ith;
const int nth = params->nth;
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
+ const int nc = src1->ne[0];
+ const int nr = ggml_nrows(src1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])),
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1])));
#ifndef NDEBUG
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- GGML_ASSERT(eps > 0.0f);
+ GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- GGML_ASSERT(eps > 0.0f);
+ GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+ const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t i12 = i02;
const int64_t i13 = i03;
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
- const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+ const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_float sum_xx = 0.0;
ggml_float sum_xdz = 0.0;
{
// z = rms_norm(x)
//
- // rms_norm(src0) =
+ // rms_norm(src1) =
// scale(
- // src0,
+ // src1,
// div(
// 1,
// sqrt(
// scale(
// sum(
// sqr(
- // src0)),
+ // src1)),
// (1.0/N)),
// eps))));
// postorder:
// ## op args grad
- // 00 param src0 grad[#00]
+ // 00 param src1 grad[#00]
// 01 const 1
// 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03]
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+ // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
const int ith = params->ith;
const int nth = params->nth;
- GGML_ASSERT(ne0 == ne00);
- GGML_ASSERT(ne1 == ne10);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne3 == ne13);
- GGML_ASSERT(ne03 == ne13);
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ GGML_ASSERT(ne2 % ne02 == 0);
+ GGML_ASSERT(ne3 % ne03 == 0);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float));
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16;
+ // dps == dst per src0, used for group query attention
+ const int64_t dps2 = ne2 / ne02;
+ const int64_t dps3 = ne3 / ne03;
+
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
- const int64_t i02 = i2;
- const int64_t i03 = i3;
+ const int64_t i02 = i2 / dps2;
+ const int64_t i03 = i3 / dps3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
}
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+ GGML_ASSERT(max_bias == 0.0f);
+
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
// linear runtime, no additional memory
float dot_y_dy = 0;
- ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
- ggml_vec_cpy_f32 (nc, dx, dy);
- ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
- ggml_vec_mul_f32 (nc, dx, dx, y);
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+ ggml_vec_cpy_f32 (nc, dx, dy);
+ ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+ ggml_vec_mul_f32 (nc, dx, dx, y);
+ ggml_vec_scale_f32(nc, dx, scale);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
}
}
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_soft_max_back_f32(params, dst);
+ ggml_compute_forward_soft_max_ext_back_f32(params, dst);
} break;
default:
{
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+ const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t IH = is_2D ? ne1 : 1;
const int64_t IW = ne0;
- const int64_t KH = is_2D ? ne01 : 1;
- const int64_t KW = ne00;
+ const int64_t KH = is_2D ? ne11 : 1;
+ const int64_t KW = ne10;
- const int64_t OH = is_2D ? ne12 : 1;
- const int64_t OW = ne11;
+ const int64_t OH = is_2D ? ne02 : 1;
+ const int64_t OW = ne01;
int ofs0 = is_2D ? nb3 : nb2;
int ofs1 = is_2D ? nb2 : nb1;
continue;
}
- const float * const src_data = (const float *) src1->data
+ const float * const grad_in = (const float *) src0->data
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
- grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
}
}
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const struct ggml_tensor * src0 = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * opt0 = dst->src[2];
+ const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
+ const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+ const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(opt0));
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_contiguous(src0f));
+ GGML_ASSERT(ggml_is_contiguous(src1f));
+ GGML_ASSERT(ggml_is_contiguous(grad));
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
const int64_t ith = params->ith;
const int64_t nth = params->nth;
// TODO: handle transposed/permuted matrices
- const int64_t nc = src0->ne[0];
- const int64_t nr = ggml_nrows(src0);
+ const int64_t nc = src0f->ne[0];
+ const int64_t nr = ggml_nrows(src0f);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
- const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+ const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
- float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
+ const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+ const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
#ifndef NDEBUG
for (int64_t i = 0; i < nc; ++i) {
// soft_max
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
- ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+ const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
assert(sum > 0.0);
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
- // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+ // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_scale_f32(nc, ds0, d_by_nr);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
- ggml_compute_forward_soft_max_back(params, tensor);
+ ggml_compute_forward_soft_max_ext_back(params, tensor);
} break;
case GGML_OP_ROPE:
{
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
case GGML_OP_MUL_MAT:
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+ case GGML_OP_SOFT_MAX_BACK: {
+ if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
+ return false;
+ }
+ float max_bias = 0.0f;
+
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+
+ return max_bias == 0.0f;
+ }
case GGML_OP_IM2COL_BACK:
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_OUT_PROD:
#include <cmath>
#include <cstdint>
-static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
-
- const int ne_tmp = WARP_SIZE*nclasses;
-
- extern __shared__ float tmp_all[];
- float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
- float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
-
- // Each warp first loads ne_tmp logits/labels into shared memory:
- for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
- const int ig = i0*nclasses + i; // ig == i global
-
- tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
- tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
- }
+template <bool use_shared>
+static __global__ void cross_entropy_loss_f32(
+ const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
+ extern __shared__ float tmp[];
- // Each thread in the warp then calculates the cross entropy loss for a single row.
- // TODO: pad in order to avoid shared memory bank conflicts.
+ logits += int64_t(blockIdx.x)*nclasses;
+ labels += int64_t(blockIdx.x)*nclasses;
// Find maximum for softmax:
- float max = -INFINITY;
- for (int i = 0; i < nclasses; ++i) {
- max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+ float max_logit = -INFINITY;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = logits[i];
+ max_logit = fmaxf(max_logit, val);
+
+ if (use_shared) {
+ tmp[i] = val;
+ }
}
+ max_logit = warp_reduce_max(max_logit);
// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
- for (int i = 0; i < nclasses; ++i) {
- float val = tmp_logits[lane_id*nclasses + i] - max;
- sum += expf(val);
- tmp_logits[lane_id*nclasses + i] = val;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float logit_i = use_shared ? tmp[i] : logits[i];
+ sum += expf(logit_i - max_logit);
}
+ sum = warp_reduce_sum(sum);
sum = logf(sum);
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
- for (int i = 0; i < nclasses; ++i) {
- loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float logit_i = use_shared ? tmp[i] : logits[i];
+ loss += (logit_i - max_logit - sum) * labels[i];
}
loss = -warp_reduce_sum(loss) / (float)k;
- __syncthreads();
-
- if (lane_id == 0) {
- tmp_all[warp_id] = loss;
- }
-
- __syncthreads();
-
- if (warp_id != 0) {
- return;
- }
-
- loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
- loss = warp_reduce_sum(loss);
-
- if (lane_id != 0) {
+ if (threadIdx.x != 0) {
return;
}
dst[blockIdx.x] = loss;
}
-static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+template <bool use_shared>
+static __global__ void cross_entropy_loss_back_f32(
+ const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
+ float * __restrict__ dst, const int nclasses) {
extern __shared__ float tmp[];
+ logits += int64_t(blockIdx.x)*nclasses;
+ labels += int64_t(blockIdx.x)*nclasses;
+ dst += int64_t(blockIdx.x)*nclasses;
+
float maxval = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
- const float val = logits[blockIdx.x*nclasses + i];
+ const float val = logits[i];
maxval = fmaxf(maxval, val);
- tmp[i] = val;
+
+ if (use_shared) {
+ tmp[i] = val;
+ }
}
maxval = warp_reduce_max(maxval);
float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
- const float val = expf(tmp[i] - maxval);
+ const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
sum += val;
- tmp[i] = val;
+
+ if (use_shared) {
+ tmp[i] = val;
+ } else {
+ dst[i] = val;
+ }
}
sum = warp_reduce_sum(sum);
const float sm_scale = 1.0f/sum;
- const float d_by_nrows = *loss/gridDim.x;
+ const float d_by_nrows = *grad/gridDim.x;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
- dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+ const float val = use_shared ? tmp[i] : dst[i];
+ dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
}
}
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();
- const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
- const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
- const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
+ const dim3 blocks_num(nrows, 1, 1);
+ const size_t nbytes_shared = ne00*sizeof(float);
+
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
- cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+ if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+ cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+ } else {
+ cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+ }
+ CUDA_CHECK(cudaGetLastError());
// Combine results from individual blocks:
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
- const ggml_tensor * opt0 = dst->src[2];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(opt0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(opt0));
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0f = dst->src[1];
+ const ggml_tensor * src1f = dst->src[2];
+
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1f->type == GGML_TYPE_F32);
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_scalar(grad));
+ GGML_ASSERT(ggml_is_contiguous(src0f));
+ GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(dst));
- GGML_ASSERT(ggml_are_same_shape(src0, src1));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
+ GGML_ASSERT(ggml_are_same_shape(src0f, dst));
- const int64_t ne00 = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ne00 = src0f->ne[0];
+ const int64_t nrows = ggml_nrows(src0f);
- const float * src0_d = (const float *) src0->data;
- const float * src1_d = (const float *) src1->data;
- const float * opt0_d = (const float *) opt0->data;
- float * dst_d = (float *) dst->data;
+ const float * grad_d = (const float *) grad->data;
+ const float * src0f_d = (const float *) src0f->data;
+ const float * src1f_d = (const float *) src1f->data;
+ float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1);
- const int shmem = ne00*sizeof(float);
-
- cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+ const size_t nbytes_shared = ne00*sizeof(float);
+
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+ if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+ cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+ } else {
+ cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+ }
}
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
- const void * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
- const src0_t * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
-
- const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
+
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = src0_row[i00];
}
+template<typename grad_t, typename dst_t>
+static __global__ void k_get_rows_back_float(
+ const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
+ const int col = blockIdx.x*blockDim.x + threadIdx.x;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
+
+ float sum = 0.0f;
+
+ for (int64_t i = 0; i < nrows_grad; ++i) {
+ if (rows[i] != dst_row) {
+ continue;
+ }
+ sum += grad[i*ncols + col];
+ }
+
+ dst[dst_row*ncols + col] = sum;
+}
+
template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
- const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne00, /*ne01, ne02, ne03,*/
- /*ne10, ne11,*/ ne12, /*ne13,*/
- /* s0,*/ s1, s2, s3,
- /* nb00,*/ nb01, nb02, nb03,
- s10, s11, s12/*, s13*/);
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
template<typename src0_t>
-static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
- const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda_float(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
+ GGML_ASSERT(ne13 == 1);
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
//const size_t s13 = nb13 / ggml_element_size(src1);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne00, /*ne01, ne02, ne03,*/
- /*ne10, ne11,*/ ne12, /*ne13,*/
- /* s0,*/ s1, s2, s3,
- /* nb00,*/ nb01, nb02, nb03,
- s10, s11, s12/*, s13*/);
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- const float * src0_d = (const float *)src0->data;
- const float * src1_d = (const float *)src1->data;
- float * dst_d = (float *)dst->data;
- cudaStream_t stream = ctx.stream();
+ const void * src0_d = (const void *) src0->data;
+ const int32_t * src1_d = (const int32_t *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
- const int32_t * src1_i32 = (const int32_t *) src1_d;
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
switch (src0->type) {
case GGML_TYPE_F16:
- get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_F32:
- get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
default:
// TODO: k-quants
break;
}
}
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+ const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const float * src0_d = (const float *) src0->data;
+ const int32_t * src1_d = (const int32_t *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(ne02*ne03 == 1);
+ GGML_ASSERT(ne12*ne13 == 1);
+ GGML_ASSERT(ne2*ne3 == 1);
+
+ const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne1, 1);
+
+ k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
+}
#include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
case GGML_OP_GET_ROWS:
ggml_cuda_op_get_rows(ctx, dst);
break;
+ case GGML_OP_GET_ROWS_BACK:
+ ggml_cuda_op_get_rows_back(ctx, dst);
+ break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
case GGML_OP_LEAKY_RELU:
ggml_cuda_op_leaky_relu(ctx, dst);
break;
+ case GGML_OP_SILU_BACK:
+ ggml_cuda_op_silu_back(ctx, dst);
+ break;
case GGML_OP_RMS_NORM:
ggml_cuda_op_rms_norm(ctx, dst);
break;
+ case GGML_OP_RMS_NORM_BACK:
+ ggml_cuda_op_rms_norm_back(ctx, dst);
+ break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
case GGML_OP_SOFT_MAX:
ggml_cuda_op_soft_max(ctx, dst);
break;
+ case GGML_OP_SOFT_MAX_BACK:
+ ggml_cuda_op_soft_max_back(ctx, dst);
+ break;
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
}
} break;
case GGML_OP_OUT_PROD:
- return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
return false;
}
} break;
+ case GGML_OP_GET_ROWS_BACK:
+ {
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+ } break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
}
return false;
} break;
+ case GGML_OP_SILU_BACK:
+ return ggml_is_contiguous(op->src[0]);
+ break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
case GGML_OP_NONE:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
return true;
+ case GGML_OP_SOFT_MAX_BACK: {
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+ return max_bias == 0.0f;
+ }
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK: {
const size_t ts = ggml_type_size(op->src[0]->type);
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- float2 mean_var = make_float2(0.f, 0.f);
+ x += int64_t(row)*ncols;
+ dst += int64_t(row)*ncols;
+
+ float2 mean_var = make_float2(0.0f, 0.0f);
for (int col = tid; col < ncols; col += block_size) {
- const float xi = x[row*ncols + col];
+ const float xi = x[col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
- if (block_size > WARP_SIZE) {
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
__shared__ float2 s_sum[32];
- int warp_id = threadIdx.x / WARP_SIZE;
- int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) {
- dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+ dst[col] = (x[col] - mean) * inv_std;
}
}
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
- int start = blockIdx.x * group_size;
- int end = start + group_size;
-
- start += threadIdx.x;
-
- if (end >= ne_elements) {
- end = ne_elements;
- }
+ const int start = blockIdx.x*group_size + threadIdx.x;
+ const int end = min(blockIdx.x*group_size + group_size, ne_elements);
float tmp = 0.0f; // partial sum for thread in warp
}
tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
- int warp_id = threadIdx.x / WARP_SIZE;
- int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
tmp = warp_reduce_sum(tmp);
}
- float mean = tmp / group_size;
+ const float mean = tmp / group_size;
tmp = 0.0f;
for (int j = start; j < end; j += block_size) {
- float xi = x[j] - mean;
+ const float xi = x[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
- int warp_id = threadIdx.x / WARP_SIZE;
- int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
tmp = warp_reduce_sum(tmp);
}
- float variance = tmp / group_size;
- float scale = rsqrtf(variance + eps);
+ const float variance = tmp / group_size;
+ const float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) {
dst[j] *= scale;
}
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
+ x += int64_t(row)*ncols;
+ dst += int64_t(row)*ncols;
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
- const float xi = x[row*ncols + col];
+ const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
- int warp_id = threadIdx.x / WARP_SIZE;
- int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- dst[row*ncols + col] = scale * x[row*ncols + col];
+ dst[col] = scale * x[col];
+ }
+}
+
+template <int block_size>
+static __global__ void rms_norm_back_f32(
+ const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ grad += int64_t(row)*ncols;
+ xf += int64_t(row)*ncols;
+ dst += int64_t(row)*ncols;
+
+ float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
+ float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xfi = xf[col];
+ sum_xx += xfi * xfi;
+ sum_xg += xfi * grad[col];
+ }
+
+ // sum up partial sums
+ sum_xx = warp_reduce_sum(sum_xx);
+ sum_xg = warp_reduce_sum(sum_xg);
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
+ __shared__ float s_sum_xx[32];
+ __shared__ float s_sum_xg[32];
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum_xx[warp_id] = sum_xx;
+ s_sum_xg[warp_id] = sum_xg;
+ }
+ __syncthreads();
+
+ sum_xx = s_sum_xx[lane_id];
+ sum_xx = warp_reduce_sum(sum_xx);
+
+ sum_xg = s_sum_xg[lane_id];
+ sum_xg = warp_reduce_sum(sum_xg);
+ }
+
+ const float mean_eps = sum_xx / ncols + eps;
+ const float sum_eps = sum_xx + ncols*eps;
+
+ const float scale_grad = rsqrtf(mean_eps);
+ const float scale_x = -scale_grad * sum_xg/sum_eps;
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[col] = scale_grad*grad[col] + scale_x*xf[col];
}
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
- GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+static void group_norm_f32_cuda(
+ const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
- GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
+static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+ }
+}
+
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * grad = dst->src[0]; // gradients
+ const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
+
+ const float * grad_d = (const float *) grad->data;
+ const float * src0f_d = (const float *) src0f->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(grad));
+
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0f->ne[0];
+ const int64_t nrows = ggml_nrows(src0f);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
+}
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne01 == ne11);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
- GGML_ASSERT(ne2 == src0->ne[2]);
+ GGML_ASSERT(ne2 % src0->ne[2] == 0);
+ GGML_ASSERT(ne3 % src0->ne[3] == 0);
+
GGML_ASSERT(ne2 == src1->ne[2]);
- GGML_ASSERT(ne3 == src0->ne[3]);
GGML_ASSERT(ne3 == src1->ne[3]);
const float * src0_d = (const float *) src0->data;
const float alpha = 1.0f;
const float beta = 0.0f;
- GGML_ASSERT(ne2 == 1);
- GGML_ASSERT(ne3 == 1);
CUBLAS_CHECK(cublasSetStream(handle, stream));
const bool src1_T = ggml_is_transposed(src1);
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
- CUBLAS_CHECK(
- cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
- ne0, ne1, ne01,
- &alpha, src0_d, ne00,
- src1_d, ldb,
- &beta, dst_d, ne0));
+ // data strides in dimensions 2/3
+ const size_t s02 = nb02 / sizeof(float);
+ const size_t s03 = nb03 / sizeof(float);
+ const size_t s12 = nb12 / sizeof(float);
+ const size_t s13 = nb13 / sizeof(float);
+ const size_t s2 = nb2 / sizeof(float);
+ const size_t s3 = nb3 / sizeof(float);
+
+ // dps == dst per src0, used for group query attention
+ const int64_t dps2 = ne2 / ne02;
+ const int64_t dps3 = ne3 / ne03;
+
+ // TODO batched matrix multiplication
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ CUBLAS_CHECK(
+ cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+ ne0, ne1, ne01,
+ &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+ src1_d + i3 *s13 + i2 *s12, ldb,
+ &beta, dst_d + i3 *s3 + i2 *s2, ne0));
+ }
+ }
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
- const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+ const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
- const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
- const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+ const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
template<bool forward, typename T>
static void rope_norm_cuda(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
template<bool forward, typename T>
static void rope_neox_cuda(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+ const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
template<bool forward, typename T>
static void rope_multi_cuda(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
template<bool forward, typename T>
static void rope_vision_cuda(
- const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
- const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
- const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+ const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+ const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+ const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
#include "common.cuh"
+#include "ggml.h"
#include "softmax.cuh"
+#include <cstdint>
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
return __half2float(val);
}
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <bool use_shared, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(
+ const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
+ const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+ x += int64_t(rowx)*ncols;
+ mask += int64_t(rowy)*ncols * (mask != nullptr);
+ dst += int64_t(rowx)*ncols;
+
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
- float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+ float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = -INFINITY;
break;
}
- const int64_t ix = (int64_t)rowx*ncols + col;
- const int64_t iy = (int64_t)rowy*ncols + col;
-
- const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+ const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
return;
}
- const int64_t idst = (int64_t)rowx*ncols + col;
- dst[idst] = vals[col] * inv_sum;
+ dst[col] = vals[col] * inv_sum;
+ }
+}
+
+static __global__ void soft_max_back_f32(
+ const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
+ const int tid = threadIdx.x;
+ const int rowx = blockIdx.x;
+
+ grad += int64_t(rowx)*ncols;
+ dstf += int64_t(rowx)*ncols;
+ dst += int64_t(rowx)*ncols;
+
+ float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dgf_dot += dstf[col]*grad[col];
+ }
+
+ dgf_dot = warp_reduce_sum(dgf_dot);
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
}
}
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
- const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+ const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head = nrows_x/nrows_y;
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
- if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+ if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
- const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
+static void soft_max_back_f32_cuda(
+ const float * grad, const float * dstf, float * dst,
+ const int ncols, const int nrows, const float scale, cudaStream_t stream) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+
+ soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
+}
+
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
- const float * src0_d = (const float *)src0->data;
- const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+ const float * src0_d = (const float *) src0->data;
+ const void * src1_d = src1 ? (const void *) src1->data : nullptr;
+ float * dst_d = (float *) dst->data;
- float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
float scale = 1.0f;
float max_bias = 0.0f;
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) {
- const half * src1_dd = (const half *)src1_d;
-
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else {
- const float * src1_dd = (const float *)src1_d;
-
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}
}
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // grad
+ const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+ GGML_ASSERT(max_bias == 0.0f);
+
+ soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __global__ void silu_back_f32(
+ const float * grad, const float * xf, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float xfi = xf[i];
+ const float s = 1.0f / (1.0f + expf(-xfi));
+ dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+}
+
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+ silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
+}
+
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // input from forward pass
+ const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
#define CUDA_STEP_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_SILU_BACK_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
-// ggml_soft_max_back
+// ggml_soft_max_ext_back
-static struct ggml_tensor * ggml_soft_max_back_impl(
+static struct ggml_tensor * ggml_soft_max_ext_back_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ float scale,
+ float max_bias,
bool inplace) {
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->src[0] = a;
result->src[1] = b;
+ memcpy((float *) result->op_params + 0, &scale, sizeof(float));
+ memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
+
return result;
}
-struct ggml_tensor * ggml_soft_max_back(
+struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
- struct ggml_tensor * b) {
- return ggml_soft_max_back_impl(ctx, a, b, false);
+ struct ggml_tensor * b,
+ float scale,
+ float max_bias) {
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
}
-struct ggml_tensor * ggml_soft_max_back_inplace(
+struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
- struct ggml_tensor * b) {
- return ggml_soft_max_back_impl(ctx, a, b, true);
+ struct ggml_tensor * b,
+ float scale,
+ float max_bias) {
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
}
// ggml_rope
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
- GGML_ASSERT(ggml_is_scalar(c));
+ GGML_ASSERT(ggml_is_scalar(a));
+ GGML_ASSERT(ggml_are_same_shape(b, c));
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->src[0] = a;
}
static void ggml_compute_backward(
- struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
struct ggml_tensor * tensor = cgraph->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
if (src0_needs_grads) {
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
}
} break;
case GGML_OP_MUL_MAT: {
} break;
case GGML_OP_SOFT_MAX: {
if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
+
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
}
GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
} break;
const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
- ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
}
} break;
case GGML_OP_POOL_2D: {
} break;
case GGML_UNARY_OP_SILU: {
if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
}
} break;
case GGML_UNARY_OP_EXP: {
} break;
case GGML_OP_CROSS_ENTROPY_LOSS: {
if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
}
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break;