cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
- const int ne00 = src0 ? src0->ne[0] : 0;
- const int ne01 = src0 ? src0->ne[1] : 0;
- const int ne02 = src0 ? src0->ne[2] : 0;
- const int ne03 = src0 ? src0->ne[3] : 0;
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const cl_long nb01 = src0->nb[1];
+ const cl_long nb02 = src0->nb[2];
+ const cl_long nb03 = src0->nb[3];
+
+ const int ne12 = src1 ? src1->ne[2] : 0;
+ const int ne13 = src1 ? src1->ne[3] : 0;
+
+ const cl_long nb11 = src1 ? src1->nb[1] : 0;
+ const cl_long nb12 = src1 ? src1->nb[2] : 0;
+ const cl_long nb13 = src1 ? src1->nb[3] : 0;
+
+ const cl_long nb1 = dst->nb[1];
+ const cl_long nb2 = dst->nb[2];
+ const cl_long nb3 = dst->nb[3];
float scale, max_bias;
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
- const int nrows_x = ggml_nrows(src0);
- const int nrows_y = src0->ne[1];
-
- const int n_head = nrows_x/nrows_y;
+ const int n_head = src0->ne[2];
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};