]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (llama/16917)
authormnehete32 <redacted>
Sun, 2 Nov 2025 03:12:57 +0000 (08:42 +0530)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh

index 61a8f1df87de1872c4eb825c6b07c96f5cc51ae4..5667ec0c4d709be51f7b02b8775cc57f9acb61b3 100644 (file)
@@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_XIELU:
                     ggml_cuda_op_xielu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_FLOOR:
+                    ggml_cuda_op_floor(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_CEIL:
+                    ggml_cuda_op_ceil(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ROUND:
+                    ggml_cuda_op_round(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TRUNC:
+                    ggml_cuda_op_trunc(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
index 5f0d3a6726aefd4806ca6cf54ceb63a6c9e49ac7..c1dc6ddbf8f81ce94e133d9c4307ee32bff30f05 100644 (file)
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
     return (x > 0.f) ? x : expm1f(x);
 }
 
+static __device__ __forceinline__ float op_floor(float x) {
+    return floorf(x);
+}
+
+static __device__ __forceinline__ float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static __device__ __forceinline__ float op_round(float x) {
+    return round(x);
+}
+
+static __device__ __forceinline__ float op_trunc(float x) {
+    return trunc(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_elu>(ctx, dst);
 }
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_floor>(ctx, dst);
+}
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_ceil>(ctx, dst);
+}
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_round>(ctx, dst);
+}
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_trunc>(ctx, dst);
+}
 /* gated ops */
 
 template <float (*op)(float), typename T>
index 6c738cefecfd21cc0113228044792f813a576c3c..2800c75ba3f7abc50f6625724ef79059f3971f95 100644 (file)
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);