case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
- // TODO: support attention sinks [TAG_ATTN_SINKS]
- return op->src[2] == nullptr;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
GGML_ASSERT(src1->extra);
}
+ const ggml_tensor * src2 = dst->src[2];
+ if (src2) {
+ GGML_ASSERT(src2->extra);
+ }
+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
- CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
- CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
- CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
- CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
- CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};