]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add set rows for f32 and f16 (llama/14551)
authorAman Gupta <redacted>
Sat, 12 Jul 2025 13:31:38 +0000 (21:31 +0800)
committerGeorgi Gerganov <redacted>
Sat, 19 Jul 2025 14:47:23 +0000 (17:47 +0300)
* CUDA: add set rows for f32 and f16

* Review: change kernel params, use strides from host

* Use 1-d kernel

* Review: use int64_t for blockDim.x, rename nb->s for clarity

src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/set-rows.cu [new file with mode: 0644]
src/ggml-cuda/set-rows.cuh [new file with mode: 0644]

index 72406f0af36225dd4f6a4c15d01bd888d88fa5a8..88b17dd682c95845cc356c7e8de7370ce7457165 100644 (file)
@@ -43,6 +43,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cuda_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_I64;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/src/ggml-cuda/set-rows.cu b/src/ggml-cuda/set-rows.cu
new file mode 100644 (file)
index 0000000..d8b3e63
--- /dev/null
@@ -0,0 +1,130 @@
+#include "set-rows.cuh"
+
+typedef void (*set_rows_kernel_t)(const char * src, char * dst);
+
+template<typename src_t, typename dst_t>
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
+    *dst_h = __float2half(*src_f);
+}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
+    *dst_f = *src_f;
+}
+
+template<typename src_t, typename dst_t>
+static __global__ void k_set_rows(
+        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t s10, const int64_t s11, const int64_t s12,
+        const int64_t s1, const int64_t s2, const int64_t s3) {
+
+    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+
+    if (i >= ne_total) {
+        return;
+    }
+
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
+
+    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;
+
+    const src_t* src_elem = src0_row + i00;
+    dst_t* dst_elem = dst_row_ptr + i00;
+    set_rows_1(src_elem, dst_elem);
+}
+
+template<typename src_t, typename dst_t>
+static void set_rows_cuda(
+        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+    const dim3 grid_size(num_blocks);
+
+
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1  = nb1/sizeof(dst_t);
+    const int64_t s2  = nb2/sizeof(dst_t);
+    const int64_t s3  = nb3/sizeof(dst_t);
+
+    if (ne_total > 0) {
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(
+            src0_d, src1_d, dst_d,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            s01, s02, s03,
+            s10, s11, s12,
+            s1, s2, s3);
+    }
+}
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float * src0_d   = (const float *)src0->data;
+    const int64_t * src1_d = (const int64_t *)src1->data;
+
+    cudaStream_t stream = ctx.stream();
+
+
+
+    if (dst->type == GGML_TYPE_F32) {
+        set_rows_cuda(
+            src0_d, src1_d, (float*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_F16) {
+        set_rows_cuda(
+            src0_d, src1_d, (half*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else {
+        GGML_ABORT("unsupported type");
+    }
+}
diff --git a/src/ggml-cuda/set-rows.cuh b/src/ggml-cuda/set-rows.cuh
new file mode 100644 (file)
index 0000000..c140c08
--- /dev/null
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);