]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Fix WARP_SIZE=16 bug of Intel GPU (llama/8266)
authorluoyu-intel <redacted>
Fri, 5 Jul 2024 05:06:13 +0000 (05:06 +0000)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
* fix group_norm ut

* split softmax

* fix softmax

* add concat support condition

* revert debug code

* move QK_WARP_SIZE to presets.hpp

src/CMakeLists.txt
src/ggml-sycl.cpp
src/ggml-sycl/backend.hpp
src/ggml-sycl/dmmv.cpp
src/ggml-sycl/norm.cpp
src/ggml-sycl/presets.hpp
src/ggml-sycl/softmax.cpp [new file with mode: 0644]
src/ggml-sycl/softmax.hpp [new file with mode: 0644]

index 08b71d410d82e27a6a622ca09b546db76cd5cb1f..8d96a04b57bebfa02c6133ced276ecb7a8259bf9 100644 (file)
@@ -490,7 +490,7 @@ if (GGML_SYCL)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
         add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
     else()
-        add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+        add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     endif()
 
     file(GLOB   GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
index dde55335bb6da79f1d02d75cea0931be3615b40e..053cc950a8a39e5ca5879230c4c630ed0a6979d0 100644 (file)
@@ -892,117 +892,6 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
-                         const int nrows_y, const float scale, const float max_bias, const float m0,
-                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid = item_ct1.get_local_id(2);
-    const int rowx = item_ct1.get_group(2);
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
-
-    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = sycl::pow(base, float(exp));
-    }
-
-    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
-    float max_val = -INFINITY;
-
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
-
-        vals[col] = val;
-        max_val = sycl::max(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val, item_ct1);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = max_val;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        max_val = buf[lane_id];
-        max_val = warp_reduce_max(max_val, item_ct1);
-    }
-
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-                if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const float val = sycl::native::exp(vals[col] - max_val);
-        tmp += val;
-        vals[col] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        if (warp_id == 0) {
-            buf[lane_id] = 0.f;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = tmp;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        tmp = buf[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    const float inv_sum = 1.f / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            return;
-        }
-
-        const int idst = rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
-    }
-}
-
 static void scale_f32(const float * x, float * dst, const float scale, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -1890,106 +1779,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                          });
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
-                                   const int nrows_y, const float scale, const float max_bias, const float m0,
-                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
-                                   const size_t n_local_scratch, queue_ptr stream) {
-    stream->submit([&](sycl::handler &cgh) {
-        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
-                                                                             nrows_y, scale, max_bias, m0,
-                                                                             m1, n_head_log2, item_ct1,
-                                                                             local_buf_acc.get_pointer());
-            });
-    });
-}
-
-static void soft_max_f32_sycl(const float * x, const float * mask,
-                              float * dst, const int ncols_x, const int nrows_x,
-                              const int nrows_y, const float scale, const float max_bias,
-                              queue_ptr stream, int device) {
-    int nth = WARP_SIZE;
-    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
-    while (nth < ncols_x && nth < max_block_size) nth *= 2;
-    if (nth>max_block_size) nth = max_block_size;
-
-    const sycl::range<3> block_dims(1, 1, nth);
-    const sycl::range<3> block_nums(1, 1, nrows_x);
-    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
-
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
-    if (n_local_scratch*sizeof(float) < local_mem_size) {
-        if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                               max_bias, m0, m1, n_head_log2, block_nums,
-                                               block_dims, n_local_scratch, stream);
-            return;
-        }
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                   max_bias, m0, m1, n_head_log2, block_nums,
-                                                   block_dims, n_local_scratch, stream);
-                break;
-        }
-    } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                            max_bias, m0, m1, n_head_log2, block_nums,
-                                            block_dims, WARP_SIZE, stream);
-    }
-}
-
 template <typename T>
 static void im2col_sycl(const float *x, T *dst, int IW, int IH,
                                 int OW, int OH, int KW, int KH, int IC,
@@ -3009,33 +2798,6 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
-
-    float scale = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
-
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
-}
-
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                                ggml_tensor *dst, const float *src0_dd,
                                const float *src1_dd, float *dst_dd,
@@ -5532,7 +5294,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+                int dim = op->op_params[0];
+                return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
             } break;
         case GGML_OP_DUP:
         case GGML_OP_NONE:
index 3afa3391938f2c675ff6dbd47e1527bd3537b08a..2a789edfc909d589b5d4a046f6acc3f4b84ee7cb 100644 (file)
@@ -21,5 +21,6 @@
 #include "mmvq.hpp"
 #include "rope.hpp"
 #include "norm.hpp"
+#include "softmax.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
index 927819281fd0a00d687492bef1ca39775b6ebaef..70a94fc16b99d024b634a128b75df78a99de3085 100644 (file)
@@ -3,6 +3,7 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
+
 static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
         });
 }
@@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
index ed0fa7e31762b5a8e80bb18cb372fad6321f9f8f..e0c5dfeca96ca63b5b029c91870fba15b68f0958 100644 (file)
@@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     const int nwarps = nthreads / WARP_SIZE;
     assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);
+    int nreduce = nwarps / WARP_SIZE;
 
     if (end >= ne_elements) {
         end = ne_elements;
@@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
         */
         item_ct1.barrier();
         tmp = 0.f;
-        int nreduce = nwarps / WARP_SIZE;
         for (size_t i = 0; i < nreduce; i += 1)
         {
             tmp += s_sum[lane_id + i * WARP_SIZE];
@@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
         better performance if there is no access to global memory.
         */
         item_ct1.barrier();
-        tmp = s_sum[lane_id];
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1)
+        {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
index c09c75dc7c73c182221155b8f48d6b6828a4dfc7..15ddcac1fa14861ea29cefd8cca0b4ac5c8afca5 100644 (file)
@@ -62,4 +62,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
+#define QK_WARP_SIZE 32
 #endif // GGML_SYCL_PRESETS_HPP
diff --git a/src/ggml-sycl/softmax.cpp b/src/ggml-sycl/softmax.cpp
new file mode 100644 (file)
index 0000000..e624b6b
--- /dev/null
@@ -0,0 +1,250 @@
+#include "norm.hpp"
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid = item_ct1.get_local_id(2);
+    const int rowx = item_ct1.get_group(2);
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+
+    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+    const int nthreads = block_size;
+    const int nwarps = nthreads / WARP_SIZE;
+    int nreduce = nwarps / WARP_SIZE;
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = sycl::pow(base, float(exp));
+    }
+
+    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float max_val = -INFINITY;
+
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val, item_ct1);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = -INFINITY;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        max_val = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+        }
+        max_val = warp_reduce_max(max_val, item_ct1);
+    }
+
+    float tmp = 0.f;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+                if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = sycl::native::exp(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = 0.f;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        tmp = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            tmp += buf[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float inv_sum = 1.f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+                                   const int nrows_y, const float scale, const float max_bias, const float m0,
+                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                   const size_t n_local_scratch, queue_ptr stream) {
+    stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                                                                             nrows_y, scale, max_bias, m0,
+                                                                             m1, n_head_log2, item_ct1,
+                                                                             local_buf_acc.get_pointer());
+            });
+    });
+}
+
+static void soft_max_f32_sycl(const float * x, const float * mask,
+                              float * dst, const int ncols_x, const int nrows_x,
+                              const int nrows_y, const float scale, const float max_bias,
+                              queue_ptr stream, int device) {
+    int nth = WARP_SIZE;
+    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+    while (nth < ncols_x && nth < max_block_size) nth *= 2;
+    if (nth>max_block_size) nth = max_block_size;
+
+    const sycl::range<3> block_dims(1, 1, nth);
+    const sycl::range<3> block_nums(1, 1, nrows_x);
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+    if (n_local_scratch*sizeof(float) < local_mem_size) {
+        if (ncols_x > max_block_size) {
+            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                               max_bias, m0, m1, n_head_log2, block_nums,
+                                               block_dims, n_local_scratch, stream);
+            return;
+        }
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 64:
+                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 128:
+                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 256:
+                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 512:
+                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 1024:
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 2048:
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 4096:
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            default:
+                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                   max_bias, m0, m1, n_head_log2, block_nums,
+                                                   block_dims, n_local_scratch, stream);
+                break;
+        }
+    } else {
+        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                            max_bias, m0, m1, n_head_log2, block_nums,
+                                            block_dims, WARP_SIZE, stream);
+    }
+}
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_dd, const float *src1_dd,
+                                  float *dst_dd,
+                                  const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
+        nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+}
diff --git a/src/ggml-sycl/softmax.hpp b/src/ggml-sycl/softmax.hpp
new file mode 100644 (file)
index 0000000..bdb8f71
--- /dev/null
@@ -0,0 +1,24 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_SOFTMAX_HPP
+#define GGML_SYCL_SOFTMAX_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
+    const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_dd, const float *src1_dd,
+    float *dst_dd,
+    const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_SOFTMAX_HPP