From: yael-works Date: Mon, 15 Sep 2025 16:51:35 +0000 (+0300) Subject: SYCL: Add COUNT_EQUAL operator support (llama/15991) X-Git-Tag: v0.9.1~26 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=6cd764a3740830b0b333959fe63fa4e99e84b736;p=pkg%2Fggml%2Fsources%2Fggml SYCL: Add COUNT_EQUAL operator support (llama/15991) * SYCL: Add COUNT_EQUAL operator support (rebased on master) * SYCL: remove duplicate op_count_equal definition * tests: remove test_count_equal_typed and use test_count_equal for all cases * tests: keep only I32 case for COUNT_EQUAL as suggested * tests: keep only I32 case for COUNT_EQUAL as requested --- diff --git a/src/ggml-sycl/binbcast.cpp b/src/ggml-sycl/binbcast.cpp index 0a3883ae..e0a1de0f 100644 --- a/src/ggml-sycl/binbcast.cpp +++ b/src/ggml-sycl/binbcast.cpp @@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } +inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); @@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sub(ctx, dst); } +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_count_equal(ctx, dst); +} + void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_mul(ctx, dst); diff --git a/src/ggml-sycl/binbcast.hpp b/src/ggml-sycl/binbcast.hpp index 9cce0f05..34c4064f 100644 --- a/src/ggml-sycl/binbcast.hpp +++ b/src/ggml-sycl/binbcast.hpp @@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) { return a - b; } +static __dpct_inline__ float op_count_equal(const float a, const float b) { + return (a == b) ? 1.0f : 0.0f; +} + +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + static __dpct_inline__ float op_mul(const float a, const float b) { return a * b; } diff --git a/src/ggml-sycl/ggml-sycl.cpp b/src/ggml-sycl/ggml-sycl.cpp index e06ec613..9404e3ff 100644 --- a/src/ggml-sycl/ggml-sycl.cpp +++ b/src/ggml-sycl/ggml-sycl.cpp @@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SUB: ggml_sycl_sub(ctx, dst); break; + case GGML_OP_COUNT_EQUAL: + ggml_sycl_count_equal(ctx, dst); + break; case GGML_OP_ACC: ggml_sycl_acc(ctx, dst); break; @@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_SUB: + case GGML_OP_COUNT_EQUAL: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_REPEAT: