return a / b;
}
-
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- const int ne0, const int ne1, const int ne2, const int ne3,
- const int ne10, const int ne11, const int ne12, const int ne13,
- /*int s0, */ const int s1, const int s2, const int s3,
- /*int s00,*/ const int s01, const int s02, const int s03,
- /*int s10,*/ const int s11, const int s12, const int s13,
- src1_ptrs... src1s) {
- const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
- const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
- const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
- const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
-
- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+template <float (*bin_op)(const float, const float),
+ typename src0_t,
+ typename src1_t,
+ typename dst_t,
+ typename... src1_ptrs>
+static __global__ void k_bin_bcast(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const uint3 ne3,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*int s0, */ const int s1,
+ const int s2,
+ const int s3,
+ /*int s00,*/ const int s01,
+ const int s02,
+ const int s03,
+ /*int s10,*/ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
+ const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+ const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
+ const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+ const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
return;
}
- const int i11 = i1 % ne11;
- const int i12 = i2 % ne12;
- const int i13 = i3 % ne13;
+ const uint32_t i11 = fastmodulo(i1, ne11);
+ const uint32_t i12 = fastmodulo(i2, ne12);
+ const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
- for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
- const int i10 = i0 % ne10;
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+ const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
}
}
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- const int ne0, const int ne1, const int ne2,const int ne3,
- const int ne10, const int ne11, const int ne12, const int ne13,
- /*int s0, */ const int s1, const int s2, const int s3,
- /*int s00,*/ const int s01, const int s02, const int s03,
- /*int s10,*/ const int s11, const int s12, const int s13,
- src1_ptrs ... src1s) {
+template <float (*bin_op)(const float, const float),
+ typename src0_t,
+ typename src1_t,
+ typename dst_t,
+ typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const uint3 ne0,
+ const uint3 ne1,
+ const uint3 ne2,
+ const uint32_t ne3,
+ const uint3 prod_012,
+ const uint3 prod_01,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*int s0, */ const int s1,
+ const int s2,
+ const int s3,
+ /*int s00,*/ const int s01,
+ const int s02,
+ const int s03,
+ /*int s10,*/ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
- const int i3 = i/(ne2*ne1*ne0);
- const int i2 = (i/(ne1*ne0)) % ne2;
- const int i1 = (i/ne0) % ne1;
- const int i0 = i % ne0;
+ const uint32_t i3 = fastdiv(i, prod_012);
+ const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+ const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+ const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return;
}
- const int i11 = i1 % ne11;
- const int i12 = i2 % ne12;
- const int i13 = i3 % ne13;
+ const int i11 = fastmodulo(i1, ne11);
+ const int i12 = fastmodulo(i2, ne12);
+ const int i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
- const int i10 = i0 % ne10;
+ const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
- dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+ const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+ const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+ const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+ const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
if (block_nums.z > 65535) {
- int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+ const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
+ const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+ const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+ const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
if constexpr (sizeof...(I) > 0) {
- k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
- <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13,
- (const src1_t *) dst->src[I + 1]->data...);
+ k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+ ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
- <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13);
+ <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+ ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13);
}
} else {
+ const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
- k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
- <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13,
- (const src1_t *) dst->src[I + 1]->data...);
+ k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
- k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
- <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12,s13);
+ k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12, s13);
}
}
}