}
}
-
-
struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
}
}
+static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ return true;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[0] % 4 == 0;
+ } break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ struct ggml_tensor * a;
+ struct ggml_tensor * b; UNUSED(b);
+ if (op->op == GGML_OP_MUL_MAT) {
+ a = op->src[0];
+ b = op->src[1];
+ } else {
+ a = op->src[2];
+ b = op->src[1];
+ }
+ if (a->ne[3] != 1) {
+ return false;
+ }
+ if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
+ return false;
+ }
+ return true;
+ } break;
+ default:
+ return false;
+ }
+}
void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
} break;
}
+ GGML_ASSERT(ggml_metal_supports_op(dst));
+
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
}
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- switch (op->op) {
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_SILU:
- case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_GELU:
- return true;
- default:
- return false;
- }
- break;
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- case GGML_OP_CONCAT:
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_SCALE:
- case GGML_OP_SQR:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_RMS_NORM:
- case GGML_OP_NORM:
- case GGML_OP_ALIBI:
- case GGML_OP_ROPE:
- case GGML_OP_IM2COL:
- case GGML_OP_ARGSORT:
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- return true;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_GET_ROWS:
- {
- // TODO: also check during graph_compute
- return op->ne[0] % 4 == 0;
- } break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_MUL_MAT_ID:
- {
- // TODO: also check during graph_compute
- struct ggml_tensor * a;
- struct ggml_tensor * b; UNUSED(b);
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
- if (a->ne[3] != 1) {
- return false;
- }
- if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
- return false;
- }
- return true;
- } break;
- default:
- return false;
- }
+ return ggml_metal_supports_op(op);
UNUSED(backend);
}