}
}
-static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
- const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- if (idx <= nrows) {
- offsets[idx] = idx * ncols;
- }
-}
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
- ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
- int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
- const dim3 offset_grid((nrows + block_size - 1) / block_size);
- init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+ auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
- d_offsets, d_offsets + 1, stream);
+ offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
- dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+ stream);
}
}
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
- ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
- temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
- stream);
+ temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+ offset_iterator + 1, stream);
}
}
}