// ggml_compute_forward_concat
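+// generic fallback: concatenates src0 and src1 along the dimension stored in op_params[0],
+// copying one element of ggml_type_size(src0->type) bytes at a time with memcpy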
+static void ggml_compute_forward_concat_any(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const size_t len = ggml_type_size(src0->type);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const char * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
+                    } else {
+                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
+                    }
+
+                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
+
+                    memcpy(y, x, len);
+                }
+            }
+        }
+    }
+}
+
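+// specialization for 1-byte element types (e.g. I8): copies elements by direct assignment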
+static void ggml_compute_forward_concat_i8(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const int8_t * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
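+// specialization for 2-byte element types; the dispatcher routes F16, BF16 and I16 here,
+// since the raw 16-bit value is copied without any conversion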
+static void ggml_compute_forward_concat_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const ggml_fp16_t * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
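+// 4-byte element path, used by the dispatcher below for F32 and I32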
static void ggml_compute_forward_concat_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

-    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

static void ggml_compute_forward_concat(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_concat_f16(params, dst);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                ggml_compute_forward_concat_i8(params, dst);
+            } break;
        case GGML_TYPE_F32:
        case GGML_TYPE_I32:
            {
                ggml_compute_forward_concat_f32(params, dst);
            } break;
        default:
            {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_concat_any(params, dst);
            }
    }
}