]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add ELU support (#14657)
authorYavor Ivanov <redacted>
Sun, 13 Jul 2025 09:33:16 +0000 (02:33 -0700)
committerGitHub <redacted>
Sun, 13 Jul 2025 09:33:16 +0000 (11:33 +0200)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh

index 1478245998a3dd17793e4e9ec42c7b9e06e83bfa..c7222207efed626ea0c9ebba9953204e74e092e1 100644 (file)
@@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_EXP:
                     ggml_cuda_op_exp(ctx, dst);
                     break;
+                case GGML_UNARY_OP_ELU:
+                    ggml_cuda_op_elu(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_ELU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
index f9c7b83c40d1bb83786ff04be0c1299655628f28..91c830c4dacc3816778650620fcc5eaae8fb5462 100644 (file)
@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
     return logf(x);
 }
 
+static __device__ __forceinline__ float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_elu>(ctx, dst);
+}
 /* gated ops */
 
 template <float (*op)(float), typename T>
index 289d690e5cff6c0b63fed83e5da42bd846a52104..cb14d16f8f3f56b82701b4598970581725bf86ee 100644 (file)
@@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);