size_t src1_stride_size = sizeof(cuda_t);
- dim3 block_dims(ne13, ne12);
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+ const int threads_x = 16;
+ const int threads_y = 16;
+ dim3 block_dims(threads_x, threads_y);
+
+ dim3 grid_dims(
+ (ne13 + threads_x - 1) / threads_x,
+ (ne12 + threads_y - 1) / threads_y
+ );
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
src0_ptr, src1_ptr, dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,