]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sycl : Fixes to broken builds and test-backend-ops (llama/10257)
authorAlberto Cabrera Pérez <redacted>
Wed, 13 Nov 2024 09:40:57 +0000 (09:40 +0000)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
* Fixes broken build for the SYCL CUDA backend caused by non-explicit gemm call in outprod (merged in with RWKV6 in
Optimize RWKV6 Operator Naming and Implement Multi-core CPU/ SYCL Acceleration #10133)

* Marks permuted MUL_MAT as unsupported to be able to run test-backend-ops

* Fixes asserts in norm to fix debug builds.

ggml/src/ggml-sycl.cpp
ggml/src/ggml-sycl/norm.cpp
ggml/src/ggml-sycl/outprod.cpp

index 255bc64c6baddf13adb1d8e2a9820b88036c0e6a..2dba15d237e920ed06eae0c9b4fa2c2f1fe6cb4f 100644 (file)
@@ -4350,6 +4350,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (op->op == GGML_OP_MUL_MAT) {
                     a = op->src[0];
                     b = op->src[1];
+                    if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
+                        // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
+                        return false;
+                    }
                 } else {
                     a = op->src[2];
                     b = op->src[1];
index b3159b9d1b94d63db4ea421171da0fce372d8f50..72d8fdb878c8de290739f339d8e9037cd99d8d31 100644 (file)
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
 
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     int end = start + group_size;
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);
     int nreduce = nwarps / WARP_SIZE;
 
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const int tid = item_ct1.get_local_id(2);
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
index c2779df0ecf91cc38b1a397f48c984e5df239ddc..e61cdc2ca5d5377fb2e9c8a4d9b414d05c894366 100644 (file)
@@ -1,4 +1,5 @@
 #include <sycl/sycl.hpp>
+#include <oneapi/mkl.hpp>
 #include "outprod.hpp"
 
 
@@ -39,7 +40,7 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
 
     try {
         // Perform matrix multiplication using oneMKL GEMM
-        oneapi::mkl::blas::gemm(*stream,
+        oneapi::mkl::blas::column_major::gemm(*stream,
             oneapi::mkl::transpose::nontrans, src1_op,
             ne0, ne1, ne01,
             alpha,