case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
- struct ggml_tensor * a;
- struct ggml_tensor * b;
+ struct ggml_tensor * a = op->src[0];
if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
- if (a->ne[3] != b->ne[3]) {
- return false;
- }
- ggml_type a_type = a->type;
- if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
- a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
- a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
- if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+ struct ggml_tensor * b = op->src[1];
+ if (a->ne[3] != b->ne[3]) {
return false;
}
}
- return true;
+ switch (a->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q8_K:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ return true;
+ default:
+ return false;
+ }
} break;
case GGML_OP_GET_ROWS:
{
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+ GGML_TYPE_BF16,
};
// unary ops