case GGML_OP_SCALE:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
+ if (op->type == GGML_TYPE_F16) {
+ const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
+ const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
+ if (src0_ok && src1_ok) {
+ return true;
+ }
+ }
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SUB:
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
- GGML_ASSERT(src0->type == src1->type);
- GGML_ASSERT(src0->type == dst->type);
- GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-
- const int ne00 = src0->ne[0];
- const int ne01 = src0->ne[1];
- const int ne02 = src0->ne[2];
- const int ne03 = src0->ne[3];
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
- const int ne10 = src1->ne[0];
- const int ne11 = src1->ne[1];
- const int ne12 = src1->ne[2];
- const int ne13 = src1->ne[3]; UNUSED(ne13);
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const int ne13 = src1->ne[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
- const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
+ const cl_ulong nb13 = src1->nb[3];
- const int ne0 = dst->ne[0];
- const int ne1 = dst->ne[1];
- const int ne2 = dst->ne[2];
- const int ne3 = dst->ne[3];
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
- bool bcast_row = false;
cl_kernel kernel;
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- GGML_ASSERT(ggml_is_contiguous(src0));
+ const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
- // src1 is a row
+ if (bcast_row) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ne11 == 1);
+ }
- bcast_row = true;
- int ne = ne00 / 4;
-
- if (src0->type == GGML_TYPE_F32) {
+ if (dst->type == GGML_TYPE_F32) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
+ if (bcast_row) {
kernel = backend_ctx->kernel_add_row;
+ const int ne = ne00 / 4;
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
} else {
- kernel = backend_ctx->kernel_add_row_f16;
- }
-
- CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
- CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
- } else {
- if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_add;
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+ const int type_src0 = (src0->type == GGML_TYPE_F32);
+ const int type_src1 = (src1->type == GGML_TYPE_F32);
+ if (bcast_row) {
+ kernel = backend_ctx->kernel_add_row_f16;
+ const int ne = ne00 / 4;
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
} else {
kernel = backend_ctx->kernel_add_f16;
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
}
-
- CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
- CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
- CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
- CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
- CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
- CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
- CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
- CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
- CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
- CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
- CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
- CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
- CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
- CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+ } else {
+ GGML_ASSERT(false && "unsupported data types for add");
}
if (bcast_row) {
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
- local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
+ local_work_size_ptr = nullptr;
}
- backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
} else {
unsigned int nth = MIN(64, ne0);
- size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+ size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);