const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
- if (mmp == nullptr) {
+ if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
}
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
- if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
+ // detect 0213 permutation, and batch size of 1
+ src0->nb[0] <= src0->nb[2] &&
+ src0->nb[2] <= src0->nb[1] &&
+ src0->nb[1] <= src0->nb[3] &&
+ src1->nb[0] <= src1->nb[2] &&
+ src1->nb[2] <= src1->nb[1] &&
+ src1->nb[1] <= src1->nb[3] &&
+ src0->ne[3] == 1 &&
+ src1->ne[3] == 1) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
- } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
- if (mmp == nullptr) {
+ if (qx_needs_dequant) {
GGML_ABORT("fatal error");
}
const uint32_t OH = is_2D ? dst->ne[2] : 1;
const uint32_t OW = dst->ne[1];
- const uint32_t batch = src1->ne[3];
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
} break;
const uint32_t OW = dst->ne[1];
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+ const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
if (a->ne[3] != b->ne[3]) {
return false;
}
+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
+ !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
+ return false;
+ }
+
return true;
} break;
case GGML_OP_GET_ROWS: