]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SYCL : SOFTMAX F16 mask support and other fixes (llama/11261)
authorAkarshan Biswas <redacted>
Tue, 28 Jan 2025 09:56:58 +0000 (15:26 +0530)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
Implemented ggml_sycl_op_soft_max() F16 src1(mask) support for which a pragma deprecation warning was added during #5021.
To do this, had to decouple it from ggml_sycl_op_flatten which always considered src1 to be of fp32 type(many OP functions are dependent on it).

* SYCL: SOFTMAX F16 mask support and other fixes

* test-backend-ops: Add F16 mask test cases

ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/softmax.cpp
ggml/src/ggml-sycl/softmax.hpp

index ed4d8bb8baee1fc9df408ae991faae68e52ec28a..2984ed82e8a7c649ba4e4c13853a65ab1ebfdaee 100644 (file)
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_sycl_soft_max(ctx, dst);
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
index a9b3fce0dc430a6e6e8daa148d9bb1fc9ff81074..563e0655f55273edf63d2ab617114537f6fb24e8 100644 (file)
@@ -1,7 +1,7 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template<typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-        nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                          main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
index bdb8f712e32f87c1b3032741b38ef82c82ac3a85..2cf8582ec92e9ef52ebc50b446da9e2953309f9c 100644 (file)
 
 #include "common.hpp"
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
-    const ggml_tensor *src1, ggml_tensor *dst,
-    const float *src0_dd, const float *src1_dd,
-    float *dst_dd,
-    const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_SOFTMAX_HPP