case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a = op->src[0];
- if (op->op == GGML_OP_MUL_MAT) {
- struct ggml_tensor * b = op->src[1];
- if (a->ne[3] != b->ne[3]) {
- return false;
- }
+ struct ggml_tensor * b = op->src[1];
+ if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+ return false;
+ }
+ if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
+ return false;
}
switch (a->type) {
case GGML_TYPE_F32:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+ return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 128) {
return true;