const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
- if (i0 >= n_dims) {
- const int i = row_dst*ne0 + i0;
-
- dst[i + 0] = x[i + 0];
- dst[i + 1] = x[i + 1];
-
- return;
- }
-
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
+ if (i0 >= n_dims) {
+ dst[idst + 0] = x[ix + 0];
+ dst[idst + 1] = x[ix + 1];
+
+ return;
+ }
+
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
- if (i0 >= n_dims) {
- const int i = row_dst*ne0 + i0;
-
- dst[i + 0] = x[i + 0];
- dst[i + 1] = x[i + 1];
-
- return;
- }
-
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
+ if (i0 >= n_dims) {
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+ return;
+ }
+
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
- if (i0 >= n_dims) {
- const int i = row_dst*ne0 + i0;
-
- dst[i + 0] = x[i + 0];
- dst[i + 1] = x[i + 1];
-
- return;
- }
-
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
+ if (i0 >= n_dims) {
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+ return;
+ }
+
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
- if (i0 >= n_dims) {
- const int i = row * ne0 + i0;
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
- return;
- }
-
const int row0 = row % ne1;
const int channel0 = row / ne1;
const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;
+ if (i0 >= n_dims) {
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+ return;
+ }
+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
- if (i0 >= n_dims) {
- const int i = row * ne0 + i0;
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
- return;
- }
-
const int row0 = row % ne1;
const int channel0 = row / ne1;
const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
+ if (i0 >= n_dims) {
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+ return;
+ }
+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
- if (i0 >= n_dims) {
- const int i = row_dst*ne0 + i0;
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
- return;
- }
-
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+ if (i0 >= n_dims) {
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+ return;
+ }
+
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
for (bool fw : {true, false}) { // fw == forward
bool all = true;
- for (float v : { 0, 1 }) {
- for (float fs : { 1.0f, 1.4245f }) {
- for (float ef : { 0.0f, 0.7465f }) {
- for (float af : { 1.0f, 1.4245f }) {
- for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- for (bool ff : {false, true}) { // freq_factors
+ for (float fs : { 1.0f, 1.4245f }) {
+ for (float ef : { 0.0f, 0.7465f }) {
+ for (float af : { 1.0f, 1.4245f }) {
+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+ for (bool ff : {false, true}) { // freq_factors
+ for (float v : { 0, 1 }) {
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
if (all) {
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
}
if (all) {
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
}