stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
- const ggml_tensor_extra_gpu *src0_extra =
- (const ggml_tensor_extra_gpu *)src0->extra;
- const ggml_tensor_extra_gpu *src1_extra =
- (const ggml_tensor_extra_gpu *)src1->extra;
- const ggml_tensor_extra_gpu *dst_extra =
- (const ggml_tensor_extra_gpu *)dst->extra;
-
- ggml_tensor_extra_gpu src0_row_extra;
- ggml_tensor_extra_gpu src1_row_extra;
- ggml_tensor_extra_gpu dst_row_extra;
-
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
- src1_row.backend = GGML_BACKEND_TYPE_GPU;
- dst_row.backend = GGML_BACKEND_TYPE_GPU;
-
- src0_row.extra = &src0_row_extra;
- src1_row.extra = &src1_row_extra;
- dst_row.extra = &dst_row_extra;
-
- char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
- ? (char *)src0->data
- : (char *)src0_extra->data_device[ctx.device];
- char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
- ? (char *)src1->data
- : (char *)src1_extra->data_device[ctx.device];
- char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
- ? (char *)dst->data
- : (char *)dst_extra->data_device[ctx.device];
+ char *src0_original = (char *)src0->data;
+ char *src1_original = (char *)src1->data;
+ char *dst_original = (char *)dst->data;
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
const int64_t i1 = id;
const int64_t i2 = i12;
- src0_row_extra.data_device[ctx.device] =
- src0_original + i02*nb02;
- src1_row_extra.data_device[ctx.device] =
- src1_original + + i11*nb11 + i12*nb12;
- dst_row_extra.data_device[ctx.device] =
- dst_original + i1*nb1 + i2*nb2;
+ src0_row.data = src0_original + i02*nb02;
+ src1_row.data = src1_original + + i11*nb11 + i12*nb12;
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
- src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
- dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
+ src1_row.data = src1_contiguous.get();
+ dst_row.data = dst_contiguous.get();
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
});
}
- src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
+ src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
return false;
}
}
+ ggml_type src0_type = op->src[0]->type;
+ if (src0_type == GGML_TYPE_BF16) {
+ return false;
+ }
return true;
} break;
case GGML_OP_GET_ROWS: