CUDA: add roll (llama/14919)
author Aman Gupta <redacted>
Tue, 29 Jul 2025 06:45:18 +0000 (14:45 +0800)
committer Georgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
* CUDA: add roll

* Make everything const, use __restrict__
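
For reference, GGML_OP_ROLL cyclically shifts a tensor along each of its four dimensions, in the spirit of numpy.roll: output element i takes its value from input index (i - shift) wrapped back into range. A minimal 1-D CPU sketch of these semantics (illustrative only; roll_1d_ref is not part of the commit):

    // Illustrative CPU reference for the roll semantics:
    // out[i] = in[(i - shift) mod ne], applied independently per dimension.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static std::vector<float> roll_1d_ref(const std::vector<float> & in, int shift) {
        const int64_t ne = (int64_t) in.size();
        std::vector<float> out(in.size());
        for (int64_t i = 0; i < ne; ++i) {
            int64_t d = i - shift;        // source index before wrapping
            d = ((d % ne) + ne) % ne;     // wrap into [0, ne)
            out[i] = in[d];
        }
        return out;
    }

    int main() {
        for (float v : roll_1d_ref({0, 1, 2, 3, 4}, 2)) {
            printf("%g ", v);             // expected: 3 4 0 1 2
        }
        printf("\n");
    }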

src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/roll.cu [new file with mode: 0644]
src/ggml-cuda/roll.cuh [new file with mode: 0644]

diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu
index 03c380897cd8a9e29e3a44d2343add84a4abad7e..1f785796014bd90392445fb27e262fcaa366e2c6 100644 (file)
@@ -31,6 +31,7 @@
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
@@ -2419,6 +2420,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE_BACK:
             ggml_cuda_op_rope_back(ctx, dst);
             break;
+        case GGML_OP_ROLL:
+            ggml_cuda_op_roll(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -3411,6 +3415,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
             return max_bias == 0.0f;
         }
+        case GGML_OP_ROLL:
+            if(op->src[0]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            return false;
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
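
For context, a graph that contains the roll op is now routed to the new CUDA case above whenever the source tensor is F32. A hedged sketch of such a call, assuming the existing ggml_roll(ctx, a, shift0, shift1, shift2, shift3) helper declared in ggml.h (check the header for the exact signature):

    // Hedged sketch of graph-level usage that would hit the new CUDA path.
    // Assumes ggml_roll(ctx, a, s0, s1, s2, s3) from ggml.h; names here are
    // illustrative, not taken from this commit.
    #include "ggml.h"

    static struct ggml_tensor * build_rolled(struct ggml_context * ctx, struct ggml_tensor * x) {
        // Shift the fastest-varying dimension by one, leave the others alone.
        // The four shifts presumably end up in op_params[0..3], which
        // ggml_cuda_op_roll reads back on the CUDA side (see roll.cu below).
        return ggml_roll(ctx, x, 1, 0, 0, 0);
    }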
diff --git a/src/ggml-cuda/roll.cu b/src/ggml-cuda/roll.cu
new file mode 100644 (file)
index 0000000..a339dfc
--- /dev/null
@@ -0,0 +1,67 @@
+#include "ggml-cuda/common.cuh"
+#include "roll.cuh"
+
+static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
+    if (idx < 0) {
+        return idx + ne;
+    }
+    if (idx >= ne) {
+        return idx - ne;
+    }
+    return idx;
+}
+
+static __global__ void roll_f32_cuda(const float * __restrict__ src,
+                                     float * __restrict__ dst,
+                                     const int64_t ne00,
+                                     const int64_t ne01,
+                                     const int64_t ne02,
+                                     const int64_t ne03,
+                                     const int     s0,
+                                     const int     s1,
+                                     const int     s2,
+                                     const int     s3) {
+    const int64_t idx        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
+
+    if (idx >= n_elements) {
+        return;
+    }
+
+    const int64_t i0 = idx % ne00;
+    const int64_t i1 = (idx / ne00) % ne01;
+    const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
+    const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
+
+    const int64_t d0 = wrap_index(i0 - s0, ne00);
+    const int64_t d1 = wrap_index(i1 - s1, ne01);
+    const int64_t d2 = wrap_index(i2 - s2, ne02);
+    const int64_t d3 = wrap_index(i3 - s3, ne03);
+
+    dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
+        src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
+}
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    int s0 = dst->op_params[0];
+    int s1 = dst->op_params[1];
+    int s2 = dst->op_params[2];
+    int s3 = dst->op_params[3];
+
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) dst->src[0]->data;
+    float *             dst_d  = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
+
+    cudaStream_t stream = ctx.stream();
+
+    int64_t sz         = (ne00 * ne01 * ne02 * ne03);
+    int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;
+
+    roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
+        src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
+}
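
The kernel assumes a contiguous layout: it decomposes the linear thread index into (i0, i1, i2, i3) and recomposes source and destination offsets with the same strides. A standalone round-trip check of that arithmetic (illustrative, not part of the commit):

    // Round-trip check of the flat-index decomposition used by roll_f32_cuda:
    // split idx into (i0, i1, i2, i3), then recompose with contiguous strides.
    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t ne00 = 3, ne01 = 4, ne02 = 5, ne03 = 2;
        const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
        for (int64_t idx = 0; idx < n_elements; ++idx) {
            const int64_t i0 = idx % ne00;
            const int64_t i1 = (idx / ne00) % ne01;
            const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
            const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
            const int64_t back = i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0;
            assert(back == idx);
        }
        return 0;
    }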
diff --git a/src/ggml-cuda/roll.cuh b/src/ggml-cuda/roll.cuh
new file mode 100644 (file)
index 0000000..322d554
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ROLL_BLOCK_SIZE 256
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
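
Finally, a self-contained 1-D end-to-end check of the same idea, compilable with nvcc on its own. It reimplements the wrap logic and the ceiling-division launch used by ggml_cuda_op_roll rather than calling the static kernel in roll.cu, so every name here is illustrative:

    // Standalone 1-D roll test (illustrative, not part of the commit).
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    static __host__ __device__ int64_t wrap_index_1d(const int64_t idx, const int64_t ne) {
        if (idx <  0) { return idx + ne; }
        if (idx >= ne) { return idx - ne; }
        return idx;
    }

    static __global__ void roll_1d(const float * __restrict__ src, float * __restrict__ dst,
                                   const int64_t ne, const int s) {
        const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
        if (i >= ne) { return; }
        dst[i] = src[wrap_index_1d(i - s, ne)];
    }

    int main() {
        const int64_t ne = 8;
        const int     s  = 3;
        std::vector<float> h_src(ne), h_dst(ne);
        for (int64_t i = 0; i < ne; ++i) { h_src[i] = (float) i; }

        float * d_src = nullptr;
        float * d_dst = nullptr;
        cudaMalloc(&d_src, ne * sizeof(float));
        cudaMalloc(&d_dst, ne * sizeof(float));
        cudaMemcpy(d_src, h_src.data(), ne * sizeof(float), cudaMemcpyHostToDevice);

        const int block      = 256;
        const int num_blocks = (int) ((ne + block - 1) / block);  // same ceiling division as ggml_cuda_op_roll
        roll_1d<<<num_blocks, block>>>(d_src, d_dst, ne, s);

        cudaMemcpy(h_dst.data(), d_dst, ne * sizeof(float), cudaMemcpyDeviceToHost);
        for (int64_t i = 0; i < ne; ++i) { printf("%g ", h_dst[i]); }  // expected: 5 6 7 0 1 2 3 4
        printf("\n");

        cudaFree(d_src);
        cudaFree(d_dst);
        return 0;
    }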