*dsti = *xi;
}
+static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+ const sycl::half *xi = (const sycl::half *)cxi;
+ float *dsti = (float *)cdsti;
+
+ *dsti = *xi;
+}
+
static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
const int16_t *xi = (const int16_t *)cxi;
int16_t *dsti = (int16_t *)cdsti;
template <cpy_kernel_t cpy_1>
static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
- const int ne10, const int ne11, const int nb10, const int nb11, const int nb12,
- const sycl::nd_item<3> &item_ct1) {
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
// determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
- const int i02 = i / (ne00*ne01);
- const int i01 = (i - i02*ne01*ne00) / ne00;
- const int i00 = i - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
-
- const int i12 = i / (ne10*ne11);
- const int i11 = (i - i12*ne10*ne11) / ne10;
- const int i10 = i - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
template <cpy_kernel_t cpy_blck, int qk>
static void cpy_f32_q(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
- const int ne10, const int ne11, const int nb10, const int nb11, const int nb12,
- const sycl::nd_item<3> &item_ct1) {
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2)) *
qk;
return;
}
- const int i02 = i / (ne00*ne01);
- const int i01 = (i - i02*ne01*ne00) / ne00;
- const int i00 = (i - i02*ne01*ne00 - i01*ne00);
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
- const int i12 = i / (ne10*ne11);
- const int i11 = (i - i12*ne10*ne11) / ne10;
- const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, nb00, nb01,
- nb02, ne10, ne11, nb10, nb11, nb12,
+ cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
});
}
static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, nb00, nb01,
- nb02, ne10, ne11, nb10, nb11, nb12,
+ cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
});
}
static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
GGML_ASSERT(ne % QK8_0 == 0);
sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
- cx, cdst, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, item_ct1);
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
});
}
static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
GGML_ASSERT(ne % QK4_0 == 0);
sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
- cx, cdst, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, item_ct1);
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
});
}
static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
GGML_ASSERT(ne % QK4_1 == 0);
sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
- cx, cdst, ne, ne00, ne01, nb00, nb01, nb02,
- ne10, ne11, nb10, nb11, nb12, item_ct1);
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
});
}
static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, nb00, nb01,
- nb02, ne10, ne11, nb10, nb11, nb12,
+ cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
});
}
static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, nb00, nb01,
- nb02, ne10, ne11, nb10, nb11, nb12,
+ cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
});
}
static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
const int ne00, const int ne01,
- const int nb00, const int nb01,
- const int nb02, const int ne10,
- const int ne11, const int nb10,
- const int nb11, const int nb12,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
dpct::queue_ptr stream) {
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, nb00, nb01,
- nb02, ne10, ne11, nb10, nb11, nb12,
+ cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
});
}
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
- GGML_ASSERT(src0->ne[3] == 1);
+ const int64_t ne02 = src0->ne[2];
+
const int64_t nb00 = src0->nb[0];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
+ const int64_t nb03 = src0->nb[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
- GGML_ASSERT(src1->ne[3] == 1);
+ const int64_t ne12 = src1->ne[2];
+
const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
+ const int64_t nb13 = src1->nb[3];
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];
char * src1_ddc = (char *) src1_extra->data_device[g_main_device_index];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+ ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));