return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
- return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
- return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
- return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+ return true;
case GGML_OP_RMS_NORM_BACK:
- return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+ return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
- GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
- GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
- GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);