}
// non-contiguous kernel (slow)
-static __global__ void concat_f32_non_cont(
+template <int dim>
+static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
+ concat_f32_non_cont(
const char * src0,
const char * src1,
char * dst,
uint64_t nb0,
uint64_t nb1,
uint64_t nb2,
- uint64_t nb3,
- int32_t dim) {
+ uint64_t nb3){
+ static_assert(dim >= 0 && dim <= 3);
+
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
- int64_t o[4] = {0, 0, 0, 0};
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
-
const float * x;
- for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
- x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ if constexpr (dim == 0) {
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
+ } else if constexpr (dim == 1) {
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
+ } else if constexpr (dim == 2) {
+ x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
+ } else if constexpr (dim == 3) {
+ x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
+ }
}
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
- concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
- (const char *)src0->data,
- (const char *)src1->data,
- ( char *)dst->data,
+ auto launch_kernel = [&](auto dim) {
+ concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+ (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
- dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
- dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+ };
+ switch (dim) {
+ case 0:
+ launch_kernel(std::integral_constant<int, 0>{});
+ break;
+ case 1:
+ launch_kernel(std::integral_constant<int, 1>{});
+ break;
+ case 2:
+ launch_kernel(std::integral_constant<int, 2>{});
+ break;
+ case 3:
+ launch_kernel(std::integral_constant<int, 3>{});
+ break;
+ default:
+ GGML_ABORT("Invalid dim: %d", dim);
+ break;
+ }
}
}