#include "binbcast.cuh"
#include <cstdint>
+#include <utility>
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
return a / b;
}
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+
+ dst_row[i0] = (dst_t) result;
}
}
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
-
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ const int ne0, const int ne1, const int ne2,const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs ... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+
+ dst_row[i0] = (dst_t) result;
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream, std::index_sequence<I...>) {
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10 / ne0;
+ int nr1 = ne11 / ne1;
+ int nr2 = ne12 / ne2;
+ int nr3 = ne13 / ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ int64_t cne[] = { ne0, ne1, ne2, ne3 };
+ int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+ int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+ size_t cnb[] = { nb0, nb1, nb2, nb3 };
+ size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+ size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0];
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s00 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+ dim3 block_dims;
+ block_dims.x = std::min<unsigned int>(hne0, block_size);
+ block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+ if (block_nums.z > 65535) {
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+ <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+ k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+ <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...);
+ }
+ }
}
template <typename T>
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
-template<float (*bin_op)(const float, const float)>
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
template<typename src0_t, typename src1_t, typename dst_t>
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- int nr0 = ne10/ne0;
- int nr1 = ne11/ne1;
- int nr2 = ne12/ne2;
- int nr3 = ne13/ne3;
-
- int nr[4] = { nr0, nr1, nr2, nr3 };
-
- // collapse dimensions until first broadcast dimension
- int64_t cne[] = {ne0, ne1, ne2, ne3};
- int64_t cne0[] = {ne00, ne01, ne02, ne03};
- int64_t cne1[] = {ne10, ne11, ne12, ne13};
-
- size_t cnb[] = {nb0, nb1, nb2, nb3};
- size_t cnb0[] = {nb00, nb01, nb02, nb03};
- size_t cnb1[] = {nb10, nb11, nb12, nb13};
-
- auto collapse = [](int64_t cne[]) {
- cne[0] *= cne[1];
- cne[1] = cne[2];
- cne[2] = cne[3];
- cne[3] = 1;
- };
-
- auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
- cnb[1] *= cne[1];
- cnb[2] *= cne[2];
- cnb[3] *= cne[3];
- };
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
- for (int i = 0; i < 4; i++) {
- if (nr[i] != 1) {
- break;
- }
- if (i > 0) {
- collapse_nb(cnb, cne);
- collapse_nb(cnb0, cne0);
- collapse_nb(cnb1, cne1);
- collapse(cne);
- collapse(cne0);
- collapse(cne1);
- }
- }
- }
-
- {
- int64_t ne0 = cne[0];
- int64_t ne1 = cne[1];
- int64_t ne2 = cne[2];
- int64_t ne3 = cne[3];
-
- //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
- //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
- //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
- //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
-
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
- size_t nb0 = cnb[0];
- size_t nb1 = cnb[1];
- size_t nb2 = cnb[2];
- size_t nb3 = cnb[3];
-
- size_t nb00 = cnb0[0];
- size_t nb01 = cnb0[1];
- size_t nb02 = cnb0[2];
- size_t nb03 = cnb0[3];
-
- size_t nb10 = cnb1[0];
- size_t nb11 = cnb1[1];
- size_t nb12 = cnb1[2];
- size_t nb13 = cnb1[3];
-
- size_t s0 = nb0 / sizeof(dst_t);
- size_t s1 = nb1 / sizeof(dst_t);
- size_t s2 = nb2 / sizeof(dst_t);
- size_t s3 = nb3 / sizeof(dst_t);
-
- size_t s10 = nb10 / sizeof(src1_t);
- size_t s11 = nb11 / sizeof(src1_t);
- size_t s12 = nb12 / sizeof(src1_t);
- size_t s13 = nb13 / sizeof(src1_t);
-
- size_t s00 = nb00 / sizeof(src0_t);
- size_t s01 = nb01 / sizeof(src0_t);
- size_t s02 = nb02 / sizeof(src0_t);
- size_t s03 = nb03 / sizeof(src0_t);
-
- GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
- GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
- GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
- GGML_ASSERT(s0 == 1);
- GGML_ASSERT(s00 == 1);
- GGML_ASSERT(s10 == 1);
-
- const int block_size = 128;
-
- int64_t hne0 = std::max(ne0/2LL, 1LL);
-
- dim3 block_dims;
- block_dims.x = std::min<unsigned int>(hne0, block_size);
- block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
- block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
-
- dim3 block_nums(
- (hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
- (ne2*ne3 + block_dims.z - 1) / block_dims.z
- );
-
- if (block_nums.z > 65535) {
- // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
- k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- } else {
- k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- }
- }
+ launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
+ src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
}
};
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
+ (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
+ (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else {
+ fprintf(stderr,
+ "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+ __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+ GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+ switch (n_fuse) {
+ case 2:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
+ break;
+ case 3:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
+ break;
+ case 4:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
+ break;
+ case 5:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
+ break;
+ case 6:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
+ break;
+ case 7:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
+ break;
+ case 8:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false && "Unsupported n_fuse value");
+ }
+}
+
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
}
}
-template <int block_size, bool do_multiply = false>
-static __global__ void rms_norm_f32(
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
- const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
- const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
- const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32(const float * x, float * dst,
+ const int ncols,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const float eps,
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const int mul_ncols = 0,
+ const int mul_nrows = 0,
+ const int mul_nchannels = 0,
+ const int mul_nsamples = 0,
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const int add_ncols = 0,
+ const int add_nrows = 0,
+ const int add_nchannels = 0,
+ const int add_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
+ if constexpr (do_add) {
+ const int add_row = row % add_nrows;
+ const int add_channel = channel % add_nchannels;
+ const int add_sample = sample % add_nsamples;
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+ }
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- if constexpr (do_multiply) {
+ if constexpr (do_multiply && do_add) {
+ const int mul_col = col % mul_ncols;
+ const int add_col = col % add_ncols;
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
+ } else if constexpr (do_add) {
+ const int add_col = col % add_ncols;
+ dst[col] += add[add_col];
} else {
dst[col] = scale * x[col];
}
}
}
-static void rms_norm_mul_f32_cuda(
- const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
- const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
- const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
- const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const int mul_ncols,
+ const int mul_nrows,
+ const int mul_nchannels,
+ const int mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const int add_ncols,
+ const int add_nrows,
+ const int add_nchannels,
+ const int add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
- if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (add == nullptr) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ }
} else {
- const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ }
}
}
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
- rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ 0, 0, 0,
+ 0, 0, 0, 0,
+ eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor) {
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const float * add_d = nullptr;
+ const ggml_tensor * add_src = nullptr;
+
+ if (add_tensor->src[0] == mul_tensor) {
+ add_d = (float *) add_tensor->src[1]->data;
+ add_src = add_tensor->src[1];
+ } else if (add_tensor->src[1] == mul_tensor) {
+ add_d = (float *) add_tensor->src[0]->data;
+ add_src = add_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) add_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ const size_t ts_add = ggml_type_size(add_src->type);
+ GGML_ASSERT(add_src->nb[0] == ts_add);
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+ const int add_ncols = add_src->ne[0];
+ const int add_nrows = add_src->ne[1];
+ const int add_nchannels = add_src->ne[2];
+ const int add_nsamples = add_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
+ ne00,ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ add_s01, add_s02, add_s03,
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
+ eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {