git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
SYCL: Add support for FLOOR, CEIL, ROUND and TRUNC unary operators (#16613)
author    safranowith <redacted>
Mon, 20 Oct 2025 08:08:32 +0000 (11:08 +0300)
committer GitHub <redacted>
Mon, 20 Oct 2025 08:08:32 +0000 (11:08 +0300)
* SYCL: Add support for FLOOR, CEIL, ROUND and TRUNC unary operators

Clean up unrelated changes from previous commit

* Chore: remove empty lines and fix indentation

* Clean up: remove leftover blank lines and fix spacing

* chore: fix trailing whitespace and ensure final newline

* Cleanup: remove redundant declarations already defined in header

* Sync docs/ops.md with updated backend operation support

* docs: update ops.md after rebase

* docs: update ops.md - Vulkan supports SSM_CONV and SSM_SCAN

docs/ops.md
docs/ops/SYCL.csv
docs/ops/Vulkan.csv
ggml/src/ggml-sycl/element_wise.cpp
ggml/src/ggml-sycl/element_wise.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
tests/test-backend-ops.cpp
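
The new operators are reached through the public ggml unary API (ggml_floor, ggml_ceil, ggml_round, ggml_trunc), as the test changes below show. A minimal end-to-end sketch on the CPU backend, assuming a ggml build where these functions are declared in ggml.h and ggml_graph_compute_with_ctx in ggml-cpu.h:

    // Sketch only: assumes ggml_floor & co. are available, as used by the
    // tests added in this commit. Error handling omitted.
    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // four input values; ggml_floor applies element-wise
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        float * in = (float *) a->data;
        in[0] = -1.5f; in[1] = -0.5f; in[2] = 0.5f; in[3] = 1.5f;

        struct ggml_tensor * out = ggml_floor(ctx, a);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

        const float * res = (const float *) out->data;
        printf("%.1f %.1f %.1f %.1f\n", res[0], res[1], res[2], res[3]); // -2.0 -1.0 0.0 1.0

        ggml_free(ctx);
        return 0;
    }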

index 938efac815fc0535579edbc1d976b8e8fdb183af..226cd935d698a666f79535bf590e0eb57422e932 100644 (file)
@@ -22,7 +22,7 @@ Legend:
 |                           ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 |                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-|                             CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+|                             CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
 |                           CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
 |                             CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@@ -42,7 +42,7 @@ Legend:
 |                              ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 |                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 |                   FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
-|                            FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+|                            FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                            GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 |                        GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -84,7 +84,7 @@ Legend:
 |                             ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                             ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                        ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-|                            ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+|                            ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                        RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                        RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                            SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -111,6 +111,6 @@ Legend:
 |                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                         TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
-|                            TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+|                            TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                          UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
 |                            XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
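
In each of the four rows above, only the seventh status column flips from ❌ to ✅ — that is the SYCL backend, per the accompanying docs/ops/SYCL.csv changes; the CPU column was already ✅.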
index d7efa43cdf3dafb857ce86c03198f3dab4b6967f..bc6319f51fa8c76e17670ddf5ac1fcded5d55231 100644 (file)
 "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
 "SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
+"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
 "SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
 "SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
 "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
 "SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
+"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
+"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
 "SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
 "SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
 "SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
index ea252577280d5ac939d727a9d6386dfa9bd4ba48..298c2a6ccd5fcb60dc5478f743b216298610e201 100644 (file)
 "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
 "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
 "Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
-"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
+"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
 "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
index 58f5125c9cf6eb577937d45bd291835962d64b0e..810995d0cbf74277e45ebb49b5428b754a5d5205 100644 (file)
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
     return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
 }
 
+template<typename T>
+static __dpct_inline__ T op_floor(T x) {
+    return sycl::floor(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_ceil(T x) {
+    return sycl::ceil(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_round(T x) {
+    return sycl::round(x);
+}
+
+template<typename T>
+static __dpct_inline__ T op_trunc(T x) {
+    return sycl::trunc(x);
+}
+
 template<typename T>
 static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
     SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
     }
 }
 
+template<typename T>
+static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_floor(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_ceil(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_round(x[i]);
+    }
+}
+
+template<typename T>
+static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_trunc(x[i]);
+    }
+}
+
 template<typename  T>
 static void upscale(const T  *x, T *dst, const int nb00, const int nb01,
                         const int nb02, const int nb03, const int ne10, const int ne11,
@@ -897,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
         }, min_val, max_val);
 }
 
+static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
 static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1122,3 +1222,23 @@ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
     ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
 }
+
+void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_floor(ctx, dst);
+}
+
+void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_ceil(ctx, dst);
+}
+
+void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_round(ctx, dst);
+}
+
+void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_trunc(ctx, dst);
+}
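
All four dispatch helpers use the same launch shape: ceil_div(k_elements, 256) work-groups of 256 work-items, with SYCL_GLOBAL_ID_LOOP (a grid-stride loop macro internal to the SYCL backend) guarding the tail. A standalone sketch of the equivalent pattern in plain SYCL 2020, with a simple bounds check standing in for the macro and a USM-capable device assumed:

    #include <sycl/sycl.hpp>
    #include <cstdio>
    #include <vector>

    int main() {
        const int k     = 1000;
        const int block = 256;
        const int num_blocks = (k + block - 1) / block; // same rounding as ceil_div

        std::vector<float> host(k);
        for (int i = 0; i < k; ++i) host[i] = 0.1f * i;

        sycl::queue q;
        float * x   = sycl::malloc_device<float>(k, q);
        float * dst = sycl::malloc_device<float>(k, q);
        q.memcpy(x, host.data(), k * sizeof(float)).wait();

        q.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(num_blocks * block), sycl::range<1>(block)),
            [=](sycl::nd_item<1> item) {
                // bounds check: threads in the padded last work-group do nothing
                const size_t i = item.get_global_linear_id();
                if (i < (size_t) k) {
                    dst[i] = sycl::floor(x[i]);
                }
            }).wait();

        q.memcpy(host.data(), dst, k * sizeof(float)).wait();
        printf("floor(1.5) = %.1f\n", host[15]); // 1.0
        sycl::free(x, q);
        sycl::free(dst, q);
        return 0;
    }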
index ed96c55f75a7a19c11c63535e354d74a6a05f679..fcf93295cb2152a2ee7a745dc1e5c37db9b05349 100644 (file)
@@ -80,6 +80,10 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
index a7e077ec8ebe0f24567062dc68e25db8265f4f89..1a007ffe2bca65899ece2df4c6668908ae1b9f8f 100644 (file)
@@ -3698,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_UNARY_OP_ELU:
                     ggml_sycl_elu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_FLOOR:
+                    ggml_sycl_floor(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_CEIL:
+                    ggml_sycl_ceil(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ROUND:
+                    ggml_sycl_round(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TRUNC:
+                    ggml_sycl_trunc(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -4262,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
 #if defined (GGML_SYCL_F16)
                     return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
 #else
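
As gated above, the new unary ops are reported as supported only for contiguous inputs whose output type matches the input type, and this check sits inside a #if defined (GGML_SYCL_F16) block, so f16 coverage depends on that build flag; the #else branch is not shown in this hunk.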
index 82bb55ea0e18463517dc2e7b3e3f0bc5534fa755..fa98db2982ce77112c7930142afc1f804af3a87d 100644 (file)
@@ -3759,6 +3759,130 @@ struct test_clamp : public test_case {
     }
 };
 
+// GGML_OP_FLOOR
+struct test_floor : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_floor(ggml_type type = GGML_TYPE_F32,
+               std::array<int64_t, 4> ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_floor(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.0f, 10.0f);
+        }
+    }
+};
+
+// GGML_OP_CEIL
+struct test_ceil : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_ceil(ggml_type type = GGML_TYPE_F32,
+              std::array<int64_t, 4> ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_ceil(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.0f, 10.0f);
+        }
+    }
+};
+
+// GGML_OP_ROUND
+struct test_round : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_round(ggml_type type = GGML_TYPE_F32,
+               std::array<int64_t, 4> ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_round(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.0f, 10.0f);
+        }
+    }
+};
+
+// GGML_OP_TRUNC
+struct test_trunc : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_trunc(ggml_type type = GGML_TYPE_F32,
+               std::array<int64_t, 4> ne = {10, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_trunc(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.0f, 10.0f);
+        }
+    }
+};
+
 // GGML_OP_DIAG_MASK_INF
 struct test_diag_mask_inf : public test_case {
     const ggml_type type;
@@ -6585,6 +6709,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_cos       (type));
         test_cases.emplace_back(new test_clamp     (type));
         test_cases.emplace_back(new test_leaky_relu(type));
+        test_cases.emplace_back(new test_floor     (type));
+        test_cases.emplace_back(new test_ceil      (type));
+        test_cases.emplace_back(new test_round     (type));
+        test_cases.emplace_back(new test_trunc     (type));
         test_cases.emplace_back(new test_sqr       (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_sqrt      (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_log       (type, {7, 1, 5, 3}));
@@ -6592,6 +6720,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_cos       (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_clamp     (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_floor     (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_ceil      (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_round     (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_trunc     (type, {7, 1, 5, 3}));
     }
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
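
With the cases registered in make_test_cases_eval(), the new operators are covered for both the default and {7, 1, 5, 3} shapes. They can also be exercised in isolation via the existing op/backend filters of test-backend-ops, e.g. "test-backend-ops test -o FLOOR -b SYCL0" (the exact binary path depends on the build layout).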