#endif // FP16_AVAILABLE
}
+// Selects which associative operation block_reduce performs.
+enum class block_reduce_method {
+ MAX,
+ SUM,
+};
+
+// Per-(method, type) policy supplying the warp-level reduce op and the
+// operation's identity element ("sentinel"); specialized below for SUM/MAX.
+template<block_reduce_method method_t, typename T>
+struct block_reduce_policy;
+
+// True iff T is one of Ts... (fold over std::is_same_v).
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+// Dependent always-false value: lets a static_assert in a discarded
+// `if constexpr` branch fire only when that branch is actually instantiated.
+template<typename...>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+// SUM policy: warp-level addition plus its additive identity.
+template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
+ // Sums val across the lanes of the calling warp; only types with a
+ // warp_reduce_sum overload (float, float2, half2, int) are supported.
+ static __device__ T reduce(T val) {
+ if constexpr(is_any<T, float, float2, half2, int>) {
+ return warp_reduce_sum(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+
+ // Additive identity, used to pad lanes that hold no partial result.
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return 0.0f;
+ } else if constexpr (std::is_same_v<T, float2>) {
+ return make_float2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, int>) {
+ return 0;
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+};
+
+// MAX policy: warp-level maximum plus its identity (-inf).
+template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
+ // Takes the max of val across the lanes of the calling warp; only types
+ // with a warp_reduce_max overload (float, half2) are supported.
+ static __device__ T reduce(T val) {
+ if constexpr (is_any<T, float, half2>) {
+ return warp_reduce_max(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+
+ // Identity for max, used to pad lanes that hold no partial result.
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return -INFINITY;
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(-INFINITY, -INFINITY);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+};
+
+// Block-wide reduction (SUM or MAX per reduce_method_t) of one value per
+// thread. shared_vals must point to shared memory with room for one T per
+// warp (up to 32 entries for block_size <= 1024); it is only touched when
+// block_size > WARP_SIZE. block_size_template == 0 means "use blockDim.x at
+// runtime"; a non-zero value bakes the block size in at compile time.
+// Requires block_size <= 1024 and block_size % WARP_SIZE == 0.
+// All threads of the block must call this (it may execute __syncthreads()).
+template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+    val = block_reduce_policy<reduce_method_t, T>::reduce(val);
+    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+    if (block_size > WARP_SIZE) {
+        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        // Barrier BEFORE writing: callers reuse the same shared buffer for
+        // back-to-back reductions (e.g. softmax MAX then SUM on buf_iw), so a
+        // fast warp must not overwrite an entry another warp has not yet read
+        // from the previous reduction. Safe: block_size is uniform across the
+        // block, so every thread reaches this barrier.
+        __syncthreads();
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        __syncthreads();
+        // Pad lanes beyond the number of warps with the operation's identity,
+        // then reduce the per-warp partials within the first warp's lanes.
+        val = block_reduce_policy<reduce_method_t, T>::sentinel();
+        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+            val = shared_vals[lane_id];
+        }
+        return block_reduce_policy<reduce_method_t, T>::reduce(val);
+    }
+
+    return val;
+}
+
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE
}
// sum up partial sums
- mean_var = warp_reduce_sum(mean_var);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float2 s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = mean_var;
- }
- __syncthreads();
- mean_var = s_sum[lane_id];
- mean_var = warp_reduce_sum(mean_var);
- }
+ extern __shared__ float2 s_sum2[];
+ mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
tmp += x[j];
}
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / group_size;
tmp = 0.0f;
tmp += xi * xi;
}
- tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
}
// sum up partial sums
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = tid / WARP_SIZE;
- const int lane_id = tid % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = 0.0f;
- if (lane_id < (block_size / WARP_SIZE)) {
- tmp = s_sum[lane_id];
- }
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
}
// sum up partial sums
- tmp = warp_reduce_sum(tmp);
- if constexpr (block_size > WARP_SIZE) {
- static_assert(block_size == 1024, "unexpected block_size");
- __shared__ float s_sum[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- s_sum[warp_id] = tmp;
- }
- __syncthreads();
- tmp = s_sum[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
- rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
- l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
-
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
}
// find the max value in the block
- max_val = warp_reduce_max(max_val);
- if (block_size > WARP_SIZE) {
- if (warp_id == 0) {
- buf_iw[lane_id] = -INFINITY;
- }
- __syncthreads();
-
- if (lane_id == 0) {
- buf_iw[warp_id] = max_val;
- }
- __syncthreads();
-
- max_val = buf_iw[lane_id];
- max_val = warp_reduce_max(max_val);
- }
+ max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
float tmp = 0.0f; // partial sum
}
// find the sum of exps in the block
- tmp = warp_reduce_sum(tmp);
- if (block_size > WARP_SIZE) {
- __syncthreads();
- if (warp_id == 0) {
- buf_iw[lane_id] = 0.0f;
- }
- __syncthreads();
-
- if (lane_id == 0) {
- buf_iw[warp_id] = tmp;
- }
- __syncthreads();
-
- tmp = buf_iw[lane_id];
- tmp = warp_reduce_sum(tmp);
- }
+ tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
}
-
-// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
-static __device__ float two_stage_warp_reduce_max(float val) {
- val = warp_reduce_max(val);
- if (blockDim.x > WARP_SIZE) {
- assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
- __shared__ float local_vals[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- local_vals[warp_id] = val;
- }
- __syncthreads();
- val = -INFINITY;
- if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
- val = local_vals[lane_id];
- }
- return warp_reduce_max(val);
- } else {
- return val;
- }
-}
-
-static __device__ float two_stage_warp_reduce_sum(float val) {
- val = warp_reduce_sum(val);
- if (blockDim.x > WARP_SIZE) {
- assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
- __shared__ float local_vals[32];
- const int warp_id = threadIdx.x / WARP_SIZE;
- const int lane_id = threadIdx.x % WARP_SIZE;
- if (lane_id == 0) {
- local_vals[warp_id] = val;
- }
- __syncthreads();
- val = 0.0f;
- if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
- val = local_vals[lane_id];
- }
- return warp_reduce_sum(val);
- } else {
- return val;
- }
-}
-
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
+ __shared__ float shared_vals[32];
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
}
// Compute CTA-level max
- local_max = two_stage_warp_reduce_max(local_max);
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Store CTA-level max to GMEM
if (tid == 0) {
} else {
local_max = -INFINITY;
}
- local_max = two_stage_warp_reduce_max(local_max);
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
}
// Reduce divisor within CTA
- tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Store CTA-level sum to GMEM
if (tid == 0) {
} else {
tmp_expf = 0.0f;
}
- tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {