]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators (llama/16613)
authorsafranowith <redacted>
Mon, 20 Oct 2025 08:08:32 +0000 (11:08 +0300)
committerGeorgi Gerganov <redacted>
Wed, 22 Oct 2025 09:58:11 +0000 (12:58 +0300)
* SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators

Clean up unrelated changes from previous commit

* Chore: remove empty lines and fix indentation

* Clean up: remove leftover blank lines and fix spacing

* chore: fix trailing whitespace and ensure final newline

* Cleanup: remove redundant declarations already defined in header

* Sync docs/ops.md with updated backend operation support

* docs: update ops.md after rebase

* docs: update ops.md - Vulkan supports SSM_CONV and SSM_SCAN

ggml/src/ggml-sycl/element_wise.cpp
ggml/src/ggml-sycl/element_wise.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index 58f5125c9cf6eb577937d45bd291835962d64b0e..810995d0cbf74277e45ebb49b5428b754a5d5205 100644 (file)
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
     return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
 }
 
+template<typename T>
+static __dpct_inline__ T op_floor(T x) {
+    return sycl::floor(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_ceil(T x) {
+    return sycl::ceil(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_round(T x) {
+    return sycl::round(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_trunc(T x) {
+    return sycl::trunc(x);
+}
+
 template<typename T>
 static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
     SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
     }
 }
 
+template<typename T>
+static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_floor(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_ceil(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_round(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_trunc(x[i]);
+    }
+}
+
 template<typename  T>
 static void upscale(const T  *x, T *dst, const int nb00, const int nb01,
                         const int nb02, const int nb03, const int ne10, const int ne11,
@@ -897,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
         }, min_val, max_val);
 }
 
+static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
 static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1122,3 +1222,23 @@ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
     ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
 }
+
+void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_floor(ctx, dst);
+}
+
+void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_ceil(ctx, dst);
+}
+
+void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_round(ctx, dst);
+}
+
+void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_trunc(ctx, dst);
+}
index ed96c55f75a7a19c11c63535e354d74a6a05f679..fcf93295cb2152a2ee7a745dc1e5c37db9b05349 100644 (file)
@@ -80,6 +80,10 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
index a7e077ec8ebe0f24567062dc68e25db8265f4f89..1a007ffe2bca65899ece2df4c6668908ae1b9f8f 100644 (file)
@@ -3698,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_UNARY_OP_ELU:
                     ggml_sycl_elu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_FLOOR:
+                    ggml_sycl_floor(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_CEIL:
+                    ggml_sycl_ceil(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ROUND:
+                    ggml_sycl_round(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TRUNC:
+                    ggml_sycl_trunc(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -4262,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
 #if defined (GGML_SYCL_F16)
                     return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
 #else