return false;
} break;
case GGML_OP_SILU_BACK:
- return ggml_is_contiguous(op->src[0]);
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
- return ggml_is_contiguous(op->src[0]);
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}
case GGML_OP_RMS_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
- case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
- case GGML_OP_CONCAT:
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
+ return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ACC:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_SCALE:
case GGML_OP_PAD:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());
-
- for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
- test_cases.emplace_back(new test_silu_back());
- }
+ test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {