#include "norm.cuh"
+#include <cstdint>
template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
- const int tid = threadIdx.x;
+static __global__ void norm_f32(
+ const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+ const int64_t stride_sample, const float eps) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
- x += int64_t(row)*ncols;
- dst += int64_t(row)*ncols;
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
+
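+    // x may point into a non-contiguous src tensor, hence the explicit element strides;
+    // dst is always contiguous, so its offset is rebuilt from the logical indices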
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float2 mean_var = make_float2(0.0f, 0.0f);
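    // The unchanged remainder of the kernel (not shown here) accumulates the row's sum and
    // sum of squares into mean_var, reduces them across the block and normalizes.
    // A rough sketch, not the verbatim body (warp_reduce_sum comes from common.cuh):
    //
    //     for (int col = tid; col < ncols; col += block_size) {
    //         const float xi = x[col];
    //         mean_var.x += xi;       // running sum        -> mean
    //         mean_var.y += xi*xi;    // running sum of x^2 -> variance
    //     }
    //     mean_var = warp_reduce_sum(mean_var);   // + shared-memory step when block_size == 1024
    //     const float mean    = mean_var.x / ncols;
    //     const float var     = mean_var.y / ncols - mean*mean;
    //     const float inv_std = rsqrtf(var + eps);
    //     for (int col = tid; col < ncols; col += block_size) {
    //         dst[col] = (x[col] - mean) * inv_std;
    //     }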
}
template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
- const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+ const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+ const int64_t stride_sample, const float eps) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
+
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
- x += int64_t(row)*ncols;
- dst += int64_t(row)*ncols;
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
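    // The unchanged remainder (not shown) reduces the sum of squares across the block and
    // scales the row; roughly (a sketch, not the verbatim body):
    //
    //     for (int col = tid; col < ncols; col += block_size) {
    //         const float xi = x[col];
    //         tmp += xi*xi;
    //     }
    //     tmp = warp_reduce_sum(tmp);             // + shared-memory step when block_size == 1024
    //     const float scale = rsqrtf(tmp/ncols + eps);
    //     for (int col = tid; col < ncols; col += block_size) {
    //         dst[col] = scale * x[col];
    //     }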
}
}
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
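+    // one block per (row, channel, sample); the kernel reads gridDim.x/.y back as nrows/nchannels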
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
- const float * src0_d = (const float *)src0->data;
- float * dst_d = (float *)dst->data;
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(ggml_is_contiguous(src0));
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
- const int64_t ne00 = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
- norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
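+    // nb0* are byte strides, the kernel wants element strides, hence the division by the type
+    // size; nb00 == ts0 keeps the innermost dim contiguous. E.g. for a contiguous {64, 5, 4, 3}
+    // f32 tensor nb01 = 64*4 = 256 bytes -> s01 = 64 = ne00, while a view that keeps the
+    // parent's byte strides (as in the tests below) gets an s01 larger than its own ne00.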
+
+ norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
- const float * src0_d = (const float *)src0->data;
- float * dst_d = (float *)dst->data;
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(ggml_is_contiguous(src0));
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
- const int64_t ne00 = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
- rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
+
+ rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
struct test_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
- float eps;
+ const bool v; // whether a is a non-contiguous view
+ const float eps;
std::string vars() override {
- return VARS_TO_STR3(type, ne, eps);
+ return VARS_TO_STR4(type, ne, v, eps);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
+ bool v = false,
float eps = 1e-6f)
- : type(type), ne(ne), eps(eps) {}
+ : type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
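+        // non-contiguous case: the view halves every extent but keeps the parent's byte
+        // strides, so its rows are no longer tightly packed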
+ if (v) {
+ a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+ ggml_set_name(a, "view of a");
+ }
+
ggml_tensor * out = ggml_norm(ctx, a, eps);
ggml_set_name(out, "out");
struct test_rms_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
- float eps;
+ const bool v; // whether a is a non-contiguous view
+ const float eps;
std::string vars() override {
- return VARS_TO_STR3(type, ne, eps);
+ return VARS_TO_STR4(type, ne, v, eps);
}
test_rms_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
+ bool v = false,
float eps = 1e-6f)
- : type(type), ne(ne), eps(eps) {}
+ : type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_name(a, "a");
+ if (v) {
+ a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+ ggml_set_name(a, "view of a");
+ }
+
ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
ggml_set_name(out, "out");
struct test_rms_norm_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
- float eps;
+ const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
const float eps;
std::string vars() override {
- return VARS_TO_STR3(type, ne, num_groups);
+ return VARS_TO_STR4(type, ne, num_groups, eps);
}
test_group_norm(ggml_type type = GGML_TYPE_F32,
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_silu_back());
- for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
- test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
- test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+ test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+ }
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}