#include <cstddef>
#include <cstdint>
#include <float.h>
+#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
}
#endif
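+// check whether the nodes starting at node_idx match `ops` and can be executed
+// as one fused kernel; currently only the (RMS_NORM, MUL) pair is fused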
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+ return false;
+ }
+
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
+ const ggml_tensor * mul = cgraph->nodes[node_idx+1];
+
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+ //rms norm only supports F32
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
+ mul->src[1]->type != GGML_TYPE_F32 ||
+ mul->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ //if rms norm is the B operand, then we don't handle broadcast
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
+ return false;
+ }
+
+ //rms_norm kernel assumes contiguous rows
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
+
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
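+ // try to fuse this node with the next one; on success the MUL node is
+ // consumed here, so skip it in the loop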
+ if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
}
}
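+// when do_multiply is true, the kernel also multiplies the normalized values
+// by the `mul` operand, fusing GGML_OP_RMS_NORM + GGML_OP_MUL into one launch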
-template <int block_size>
+template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
- const int64_t stride_sample, const float eps) {
+ const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
+ const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
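+ // offset the fused mul operand; its indices are taken modulo its own sizes
+ // so that a smaller tensor broadcasts across rows, channels and samples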
+ if constexpr (do_multiply) {
+ const int mul_row = row % mul_nrows;
+ const int mul_channel = channel % mul_nchannels;
+ const int mul_sample = sample % mul_nsamples;
+ mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+ }
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- dst[col] = scale * x[col];
+ if constexpr (do_multiply) {
+ const int mul_col = col % mul_ncols;
+ dst[col] = scale * x[col] * mul[mul_col];
+ } else {
+ dst[col] = scale * x[col];
+ }
}
}
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ }
+}
+
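+// launcher for the fused RMS_NORM + MUL kernel; with no mul operand it falls
+// back to the plain rms_norm_f32_cuda path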
+static void rms_norm_mul_f32_cuda(
+ const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+ const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
+ const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
+ const float eps, cudaStream_t stream) {
+ if (mul == nullptr) {
+ rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+ return;
+ }
+ const dim3 blocks_num(nrows, nchannels, nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
}
}
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
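+// fused RMS_NORM + MUL: `dst` is the RMS_NORM node and `mul_tensor` the MUL
+// node consuming it; the result is written to mul_tensor->data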
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+ const ggml_tensor * rms_norm_src = dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
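+ // the RMS_NORM result can be either operand of the MUL; use the other one
+ // as the multiplier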
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (const float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (const float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) mul_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+}
+
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
+ const bool broadcast;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
bool run_whole_graph() override { return true; }
std::string vars() override {
- return VARS_TO_STR3(type, ne, eps);
+ return VARS_TO_STR4(type, ne, eps, broadcast);
}
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
- float eps = 1e-6f)
- : type(type), ne(ne), eps(eps) {}
+ float eps = 1e-6f, bool broadcast = false)
+ : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
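+ // in broadcast mode `a` is larger than `b`/`c`, so the mul operand has to
+ // broadcast across the normalized tensor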
+ std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
+
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
+
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_set_param(b);
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));