const int32_t* src2_d = (const int32_t*)src2->data;
float* dst_d = (float*)dst->data;
- int threads = std::min((int)ne00, 768); // cols
+ const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+
+ int threads = std::min((unsigned int)ne00, max_work_group_size); // cols
+
ctx.stream()->parallel_for(
sycl::nd_range<3>(
sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),