]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sycl: fix norm kernels: l2_norm, group_norm, rms_norm by remove assert to support...
authorNeo Zhang <redacted>
Thu, 29 Jan 2026 01:20:22 +0000 (09:20 +0800)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
Co-authored-by: Neo Zhang Jianyu <redacted>
src/ggml-sycl/ggml-sycl.cpp
src/ggml-sycl/norm.cpp

index ce2f0d41c96138db1b434a4149fead38c1c94f3d..3a4c092af5d1858ed194653a29b6796b25a6660f 100644 (file)
@@ -4606,14 +4606,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
 #endif
         case GGML_OP_NORM:
-            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
-            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
-            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+            return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
index 823d3a4828cc925cdeea2cb5c5434d51f371d1f5..00702b5d09cd6419bc651c5978482e31b9ad54fb 100644 (file)
@@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
         const float eps, queue_ptr stream, int device) {
 
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
 
 static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
 
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
 static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
     const int nrows, const float eps,
     queue_ptr stream, int device) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);