ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
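+ // the results are cached here because src1 (and possibly src0) may be
+ // replaced below by a dense temporary, so re-checking the original
+ // tensors with ggml_is_contiguous_2() at the GEMM call site would give
+ // the wrong answer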
+
// Handle src0
src0_ptr = (const cuda_t *) src0->data;
s11 = ne10;
s12 = ne11*s11;
s13 = ne12*s12;
+
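+ // the dense temporary filled above is contiguous across dims 2, 3 by
+ // construction, regardless of the original src1 layout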
+ is_src1_cont_2 = true;
}
// Setup destination buffer
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+ // with a [0, 2, 1, 3] permutation and ne02 == 1, the matrix strides need to be determined from dim 3:
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+ const int64_t smb = ne12 == 1 ? s13 : s12;
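+ // e.g. for src1 of logical shape [ne10, ne11, 1, ne13] the ne13 matrices
+ // advance along dim 3, so the per-matrix stride is s13; s12 only spans a
+ // size-1 dimension and can be meaningless after a permutation (sma falls
+ // back to nb03/nb00 for the same reason when ne02 == 1)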
+
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
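+ // here m = ne01, n = ne11, k = ne10; src0 is consumed transposed
+ // (CUBLAS_OP_T) and the batch covers ne12*ne13 matrices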
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
- alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
- src1_ptr, cu_data_type_b, s11, s12, // strideB
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const size_t type_size_src0 = ggml_type_size(src0->type);
const size_t type_size_src1 = ggml_type_size(src1->type);
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
// SRC1 strides
int64_t s11 = nb11 / type_size_src1;
int64_t s12 = nb12 / type_size_src1;
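+ // nb1x are byte strides; dividing by the element size yields the element
+ // strides the GEMM call below expects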
s11 = ne10;
s12 = ne11 * s11;
s13 = ne12 * s12;
+
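+ // as in the CUDA path above, the dense f16 temporary is contiguous
+ // across dims 2, 3 by construction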
+ is_src1_cont_2 = true;
}
else
#endif
{
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+ // with a [0, 2, 1, 3] permutation and ne02 == 1, the matrix strides need to be determined from dim 3:
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+ const int64_t smb = ne12 == 1 ? s13 : s12;
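+ // both are element strides: sma via the byte-stride ratio nb0x/nb00
+ // (nb00 being the element size here), smb via the precomputed element
+ // strides s12/s13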
+
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
- src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
- src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+ src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;