]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Add support for soft_max ALiBi (llama/5639)
authorAidanBeltonS <redacted>
Mon, 26 Feb 2024 14:02:11 +0000 (14:02 +0000)
committerGeorgi Gerganov <redacted>
Wed, 28 Feb 2024 11:00:29 +0000 (13:00 +0200)
* Add support for bias

* Update pre-processor

* rm commented code

* fix format

* fix CI

---------

Co-authored-by: Abhilash Majumder <redacted>
ggml-sycl.cpp

index e1a02f24653db1c9ca0b629df36edcbca9934cb6..a054ec8b92bac708fb4a3a75e9bda578fb24dd8c 100644 (file)
@@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale,
-                         const sycl::nd_item<3> &item_ct1, float *buf) {
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid = item_ct1.get_local_id(2);
     const int rowx = item_ct1.get_group(2);
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = item_ct1.get_local_range(2);
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
 
     const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 
+    float slope = 0.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = sycl::pow(base, float(exp));
+    }
+
+    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
     float max_val = -INFINITY;
 
-    for (int col = tid; col < ncols; col += block_size) {
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
     }
 
     // find the max value in the block
@@ -8151,30 +8179,12 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
         if (warp_id == 0) {
             buf[lane_id] = -INFINITY;
         }
-        /*
-        DPCT1118:12: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         if (lane_id == 0) {
             buf[warp_id] = max_val;
         }
-        /*
-        DPCT1118:13: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         max_val = buf[lane_id];
         max_val = warp_reduce_max(max_val, item_ct1);
@@ -8182,13 +8192,16 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
 
     float tmp = 0.f;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val =
-            sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+                if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = sycl::native::exp(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
@@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
         if (warp_id == 0) {
             buf[lane_id] = 0.f;
         }
-        /*
-        DPCT1118:14: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         if (lane_id == 0) {
             buf[warp_id] = tmp;
         }
-        /*
-        DPCT1118:15: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         tmp = buf[lane_id];
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.f / tmp;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
@@ -10867,37 +10869,98 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                          });
 }
 
-static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
-                              const int ncols_x, const int nrows_x,
-                              const int nrows_y, const float scale,
-                              dpct::queue_ptr stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const sycl::range<3> block_dims(1, 1, nth);
-    const sycl::range<3> block_nums(1, 1, nrows_x);
-    /*
-    DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
-    limit. To get the device limit, query info::device::max_work_group_size.
-    Adjust the work-group size if needed.
-    */
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+                                   const int nrows_y, const float scale, const float max_bias, const float m0,
+                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                   const size_t n_local_scratch, dpct::queue_ptr stream) {
     stream->submit([&](sycl::handler &cgh) {
-        /*
-        DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
-        replaced with a value. Modify the code to use the original expression,
-        provided in comments, if it is correct.
-        */
-        sycl::local_accessor<float, 1> buf_acc_ct1(
-            sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
+        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1,
-                             buf_acc_ct1.get_pointer());
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+                                                                             nrows_y, scale, max_bias, m0,
+                                                                             m1, n_head_log2, item_ct1,
+                                                                             local_buf_acc.get_pointer());
             });
     });
 }
 
+static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+                              float * dst, const int ncols_x, const int nrows_x,
+                              const int nrows_y, const float scale, const float max_bias,
+                              dpct::queue_ptr stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const sycl::range<3> block_dims(1, 1, nth);
+    const sycl::range<3> block_nums(1, 1, nrows_x);
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+    static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+    if (n_local_scratch*sizeof(float) < local_mem_size) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 64:
+                soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 128:
+                soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 256:
+                soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 512:
+                soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 1024:
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 2048:
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 4096:
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            default:
+                soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                   max_bias, m0, m1, n_head_log2, block_nums,
+                                                   block_dims, n_local_scratch, stream);
+                break;
+        }
+    } else {
+        soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                            max_bias, m0, m1, n_head_log2, block_nums,
+                                            block_dims, WARP_SIZE, stream);
+    }
+}
+
 template <typename T>
 static void im2col_sycl(const float *x, T *dst, int IW, int IH,
                                 int OW, int OH, int KW, int KH, int IC,
@@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src0->ne[1];
 
     float scale = 1.0f;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float max_bias = 0.0f;
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    memcpy(&scale, dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    (void) dst;
+    // positions tensor
+    float * src2_dd = nullptr;
+    sycl_pool_alloc<float> src2_f;
+
+    ggml_tensor * src2 = dst->src[2];
+    const bool use_src2 = src2 != nullptr;
+
+    if (use_src2) {
+        const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
+
+        if (src2_on_device) {
+            ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
+            src2_dd = (float *) src2_extra->data_device[g_main_device];
+        } else {
+            src2_dd = src2_f.alloc(ggml_nelements(src2));
+            SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
+        }
+    }
+
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+                      nrows_x, nrows_y, scale, max_bias, main_stream);
 }
 
 inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,