dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
-static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale,
- const sycl::nd_item<3> &item_ct1, float *buf) {
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ const int nrows_y, const float scale, const float max_bias, const float m0,
+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
const int tid = item_ct1.get_local_id(2);
const int rowx = item_ct1.get_group(2);
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
- const int block_size = item_ct1.get_local_range(2);
+ const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ float slope = 0.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const uint32_t h = rowx/nrows_y; // head index
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = sycl::pow(base, float(exp));
+ }
+
+ float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
float max_val = -INFINITY;
- for (int col = tid; col < ncols; col += block_size) {
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
- max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f));
+
+ const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+
+ vals[col] = val;
+ max_val = sycl::max(max_val, val);
}
// find the max value in the block
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
}
- /*
- DPCT1118:12: SYCL group functions and algorithms must be encountered in
- converged control flow. You may need to adjust the code.
- */
- /*
- DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
- better performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
if (lane_id == 0) {
buf[warp_id] = max_val;
}
- /*
- DPCT1118:13: SYCL group functions and algorithms must be encountered in
- converged control flow. You may need to adjust the code.
- */
- /*
- DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
- better performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
max_val = buf[lane_id];
max_val = warp_reduce_max(max_val, item_ct1);
float tmp = 0.f;
- for (int col = tid; col < ncols; col += block_size) {
- const int ix = rowx*ncols + col;
- const int iy = rowy*ncols + col;
- const float val =
- sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const float val = sycl::native::exp(vals[col] - max_val);
tmp += val;
- dst[ix] = val;
+ vals[col] = val;
}
// find the sum of exps in the block
if (warp_id == 0) {
buf[lane_id] = 0.f;
}
- /*
- DPCT1118:14: SYCL group functions and algorithms must be encountered in
- converged control flow. You may need to adjust the code.
- */
- /*
- DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
- better performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
if (lane_id == 0) {
buf[warp_id] = tmp;
}
- /*
- DPCT1118:15: SYCL group functions and algorithms must be encountered in
- converged control flow. You may need to adjust the code.
- */
- /*
- DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
- better performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
tmp = buf[lane_id];
tmp = warp_reduce_sum(tmp, item_ct1);
}
- const float inv_tmp = 1.f / tmp;
+ const float inv_sum = 1.f / tmp;
- for (int col = tid; col < ncols; col += block_size) {
- const int i = rowx*ncols + col;
- dst[i] *= inv_tmp;
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ return;
+ }
+
+ const int idst = rowx*ncols + col;
+ dst[idst] = vals[col] * inv_sum;
}
}
});
}
-static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
- const int ncols_x, const int nrows_x,
- const int nrows_y, const float scale,
- dpct::queue_ptr stream) {
- int nth = WARP_SIZE;
- while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
- const sycl::range<3> block_dims(1, 1, nth);
- const sycl::range<3> block_nums(1, 1, nrows_x);
- /*
- DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
- limit. To get the device limit, query info::device::max_work_group_size.
- Adjust the work-group size if needed.
- */
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ const int nrows_y, const float scale, const float max_bias, const float m0,
+ const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+ const size_t n_local_scratch, dpct::queue_ptr stream) {
stream->submit([&](sycl::handler &cgh) {
- /*
- DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
- replaced with a value. Modify the code to use the original expression,
- provided in comments, if it is correct.
- */
- sycl::local_accessor<float, 1> buf_acc_ct1(
- sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
+ sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
- soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1,
- buf_acc_ct1.get_pointer());
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+ nrows_y, scale, max_bias, m0,
+ m1, n_head_log2, item_ct1,
+ local_buf_acc.get_pointer());
});
});
}
+static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+ float * dst, const int ncols_x, const int nrows_x,
+ const int nrows_y, const float scale, const float max_bias,
+ dpct::queue_ptr stream) {
+ int nth = WARP_SIZE;
+ while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ const sycl::range<3> block_dims(1, 1, nth);
+ const sycl::range<3> block_nums(1, 1, nrows_x);
+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+ static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+ if (n_local_scratch*sizeof(float) < local_mem_size) {
+ switch (ncols_x) {
+ case 32:
+ soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 64:
+ soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 128:
+ soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 256:
+ soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 512:
+ soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 1024:
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 2048:
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 4096:
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ default:
+ soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ }
+ } else {
+ soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, WARP_SIZE, stream);
+ }
+}
+
template <typename T>
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
int OW, int OH, int KW, int KH, int IC,
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
- const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+ const int64_t nrows_y = src0->ne[1];
float scale = 1.0f;
- memcpy(&scale, dst->op_params, sizeof(float));
+ float max_bias = 0.0f;
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+ memcpy(&scale, dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, dst->op_params + 1, sizeof(float));
- (void) dst;
+ // positions tensor
+ float * src2_dd = nullptr;
+ sycl_pool_alloc<float> src2_f;
+
+ ggml_tensor * src2 = dst->src[2];
+ const bool use_src2 = src2 != nullptr;
+
+ if (use_src2) {
+ const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
+
+ if (src2_on_device) {
+ ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
+ src2_dd = (float *) src2_extra->data_device[g_main_device];
+ } else {
+ src2_dd = src2_f.alloc(ggml_nelements(src2));
+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
+ }
+ }
+
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+ nrows_x, nrows_y, scale, max_bias, main_stream);
}
inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,