]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
SYCL: Add COUNT_EQUAL operator support (llama/15991)
authoryael-works <redacted>
Mon, 15 Sep 2025 16:51:35 +0000 (19:51 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* SYCL: Add COUNT_EQUAL operator support (rebased on master)

* SYCL: remove duplicate op_count_equal definition

* tests: remove test_count_equal_typed and use test_count_equal for all cases

* tests: keep only I32 case for COUNT_EQUAL as suggested

* tests: keep only I32 case for COUNT_EQUAL as requested

src/ggml-sycl/binbcast.cpp
src/ggml-sycl/binbcast.hpp
src/ggml-sycl/ggml-sycl.cpp

index 0a3883ae1eda57017c864be9bc60ab231be8cdce..e0a1de0f322638caac16f159613dc111e29180bc 100644 (file)
@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
 }
 
+inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
+}
+
 inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
     ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_sub(ctx, dst);
 }
 
+void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_count_equal(ctx, dst);
+}
+
 void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
     ggml_sycl_op_mul(ctx, dst);
index 9cce0f053a5826949a5f354fb2b9877b5eeaa374..34c4064f5287fbb656028245c6d312b2012f8965 100644 (file)
@@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
     return a - b;
 }
 
+static __dpct_inline__ float op_count_equal(const float a, const float b) {
+    return (a == b) ? 1.0f : 0.0f;
+}
+
+void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 static __dpct_inline__ float op_mul(const float a, const float b) {
     return a * b;
 }
index e06ec613fc81f760d65ecc487fe24159776a7b8a..9404e3ff4ad9bdf450b2ef4445dcd8d608468f38 100644 (file)
@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_SUB:
             ggml_sycl_sub(ctx, dst);
             break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_sycl_count_equal(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_sycl_acc(ctx, dst);
             break;
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
+        case GGML_OP_COUNT_EQUAL:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_REPEAT: