#endif // FP16_AVAILABLE
}
+enum class block_reduce_method {
+ MAX,
+ SUM,
+};
+
+template<block_reduce_method method_t, typename T>
+struct block_reduce_policy;
+
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+template<typename...>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
+ static __device__ T reduce(T val) {
+ if constexpr(is_any<T, float, float2, half2, int>) {
+ return warp_reduce_sum(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return 0.0f;
+ } else if constexpr (std::is_same_v<T, float2>) {
+ return make_float2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, int>) {
+ return 0;
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+};
+
+template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
+ static __device__ T reduce(T val) {
+ if constexpr (is_any<T, float, half2>) {
+ return warp_reduce_max(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return -INFINITY;
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(-INFINITY, -INFINITY);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+};
+
+template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+ val = block_reduce_policy<reduce_method_t, T>::reduce(val);
+ const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+ if (block_size > WARP_SIZE) {
+ assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ shared_vals[warp_id] = val;
+ }
+ __syncthreads();
+ val = block_reduce_policy<reduce_method_t, T>::sentinel();
+ if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+ val = shared_vals[lane_id];
+ }
+ return block_reduce_policy<reduce_method_t, T>::reduce(val);
+ }
+
+ return val;
+}
+
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
- return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+ return ggml_is_contiguous(op->src[0]);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
}
// sum up partial sums
- mean_var = warp_reduce_sum(mean_var);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float2 s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = mean_var;
- }
- __syncthreads();
- mean_var = s_sum[lane_id];
- mean_var = warp_reduce_sum(mean_var);
- }
+ extern __shared__ float2 s_sum2[];
+ mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
tmp += x[j];
}
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / group_size;
tmp = 0.0f;
tmp += xi * xi;
}
- tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
}
// sum up partial sums
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = tid / WARP_SIZE;
- const int lane_id = tid % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = 0.0f;
- if (lane_id < (block_size / WARP_SIZE)) {
- tmp = s_sum[lane_id];
- }
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
}
// sum up partial sums
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
}
// sum up partial sums
- sum = warp_reduce_sum(sum);
- if (blockDim.x > WARP_SIZE) {
- assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = sum;
- }
- __syncthreads();
- sum = 0.0f;
- if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
- sum = s_sum[lane_id];
- }
- sum = warp_reduce_sum(sum);
- }
+ __shared__ float shared_vals[32];
+ sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
if (col != 0) {
return;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
-
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
}
// find the max value in the block
- max_val = warp_reduce_max(max_val);
- if (block_size > WARP_SIZE) {
- if (warp_id == 0) {
- buf_iw[lane_id] = -INFINITY;
- }
- __syncthreads();
-
- if (lane_id == 0) {
- buf_iw[warp_id] = max_val;
- }
- __syncthreads();
-
- max_val = buf_iw[lane_id];
- max_val = warp_reduce_max(max_val);
- }
+ max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
float tmp = 0.0f; // partial sum
}
// find the sum of exps in the block
- tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
- __syncthreads();
- if (warp_id == 0) {
- buf_iw[lane_id] = 0.0f;
- }
- __syncthreads();
-
- if (lane_id == 0) {
- buf_iw[warp_id] = tmp;
- }
- __syncthreads();
-
- tmp = buf_iw[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
}
-
-// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
-static __device__ float two_stage_warp_reduce_max(float val) {
- val = warp_reduce_max(val);
- if (blockDim.x > WARP_SIZE) {
- assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
- __shared__ float local_vals[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- local_vals[warp_id] = val;
- }
- __syncthreads();
- val = -INFINITY;
- if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
- val = local_vals[lane_id];
- }
- return warp_reduce_max(val);
- } else {
- return val;
- }
-}
-
-static __device__ float two_stage_warp_reduce_sum(float val) {
- val = warp_reduce_sum(val);
- if (blockDim.x > WARP_SIZE) {
- assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
- __shared__ float local_vals[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- local_vals[warp_id] = val;
- }
- __syncthreads();
- val = 0.0f;
- if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
- val = local_vals[lane_id];
- }
- return warp_reduce_sum(val);
- } else {
- return val;
- }
-}
-
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
+ __shared__ float shared_vals[32];
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
}
// Compute CTA-level max
- local_max = two_stage_warp_reduce_max(local_max);
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Store CTA-level max to GMEM
if (tid == 0) {
} else {
local_max = -INFINITY;
}
- local_max = two_stage_warp_reduce_max(local_max);
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
}
// Reduce divisor within CTA
- tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Store CTA-level sum to GMEM
if (tid == 0) {
} else {
tmp_expf = 0.0f;
}
- tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
test_cases.emplace_back(new test_silu_back());
- for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
- for (bool v : {false, true}) {
- test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
- test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+ for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
+ for (uint32_t n : { 64, 1025 }) {
+ for (bool v : { false, true }) {
+ test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+ test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+ }
+ test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
+ test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
}
- test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
- test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
// in-place tests
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
- for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
- test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
- test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
- test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
- test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
- test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
- test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+ for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
+ for (uint32_t n : { 64, 1025 }) {
+ test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+ test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+ test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+ test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+ test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+ test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+ }
}
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
for (bool multi_add : {false, true}) {
}
}
}
-
- test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
-
for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));