From: Piotr Wilkin (ilintar) Date: Tue, 9 Dec 2025 19:28:57 +0000 (+0100) Subject: Add DIAG for CUDA (llama/17873) X-Git-Tag: upstream/1.8.3~158 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2817582be254a2dd5c326063aeb9bb00f596aa3c;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp Add DIAG for CUDA (llama/17873) * Add DIAG for CUDA * Refactor parameters --- diff --git a/ggml/src/ggml-cuda/diag.cu b/ggml/src/ggml-cuda/diag.cu new file mode 100644 index 00000000..5cea2105 --- /dev/null +++ b/ggml/src/ggml-cuda/diag.cu @@ -0,0 +1,77 @@ +#include "convert.cuh" +#include "diag.cuh" +#include "ggml.h" + +template +static __global__ void diag_kernel(T * __restrict__ dst, + const T * __restrict__ src, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const int64_t total_elements) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= total_elements) { + return; + } + + const int64_t i0 = global_idx % ne0; + const int64_t i1 = (global_idx / ne0) % ne1; + const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2; + const int64_t i3 = global_idx / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + const int64_t src_idx = batch_idx * ne0 + i0; + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = ggml_cuda_cast(0); + } + GGML_UNUSED_VARS(ne3); +} + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + void * dst_d = dst->data; + const void * src0_d = src0->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + diag_kernel<<>>((float *) dst_d, (const float *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + case GGML_TYPE_F16: + diag_kernel<<>>((half *) dst_d, (const half *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/diag.cuh b/ggml/src/ggml-cuda/diag.cuh new file mode 100644 index 00000000..7d73e6a8 --- /dev/null +++ b/ggml/src/ggml-cuda/diag.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_DIAG_BLOCK_SIZE 256 + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d0463388..279679a4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -20,6 +20,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" +#include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" @@ -2641,6 +2642,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: break; + case GGML_OP_DIAG: + ggml_cuda_op_diag(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_cuda_op_diag_mask_inf(ctx, dst); break; @@ -4624,6 +4628,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_FILL: case GGML_OP_CUMSUM: case GGML_OP_TRI: + case GGML_OP_DIAG: return true; case GGML_OP_SOLVE_TRI: return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;