From: Georgi Gerganov Date: Tue, 5 Dec 2023 13:17:48 +0000 (+0200) Subject: metal : check supported ops at runtime (#632) X-Git-Tag: upstream/0.0.1642~1182 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=33d2fda74e98523154b0ddc47762d248c3d38eaa;p=pkg%2Fggml%2Fsources%2Fggml metal : check supported ops at runtime (#632) * metal : check supported ops at runtime * metal : remove TODOs --- diff --git a/src/ggml-metal.m b/src/ggml-metal.m index f2267356..cff9d5bc 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -181,8 +181,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ } } - - struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: allocating\n", __func__); @@ -773,6 +771,70 @@ void ggml_metal_graph_find_concurrency( } } +static bool ggml_metal_supports_op(const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU: + return true; + default: + return false; + } + break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + case GGML_OP_CONCAT: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SUM_ROWS: + case GGML_OP_SOFT_MAX: + case GGML_OP_RMS_NORM: + case GGML_OP_NORM: + case GGML_OP_ALIBI: + case GGML_OP_ROPE: + case GGML_OP_IM2COL: + case GGML_OP_ARGSORT: + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + return true; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_GET_ROWS: + { + return op->ne[0] % 4 == 0; + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a; + struct ggml_tensor * b; UNUSED(b); + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != 1) { + return false; + } + if (ggml_is_quantized(a->type) && a->ne[2] != 1) { + return false; + } + return true; + } break; + default: + return false; + } +} void ggml_metal_graph_compute( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { @@ -843,6 +905,8 @@ void ggml_metal_graph_compute( } break; } + GGML_ASSERT(ggml_metal_supports_op(dst)); + const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -1973,70 +2037,7 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml } static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - switch (op->op) { - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_GELU: - return true; - default: - return false; - } - break; - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SUM_ROWS: - case GGML_OP_SOFT_MAX: - case GGML_OP_RMS_NORM: - case GGML_OP_NORM: - case GGML_OP_ALIBI: - case GGML_OP_ROPE: - case GGML_OP_IM2COL: - case GGML_OP_ARGSORT: - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - return true; - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: - { - // TODO: also check during graph_compute - return op->ne[0] % 4 == 0; - } break; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - { - // TODO: also check during graph_compute - struct ggml_tensor * a; - struct ggml_tensor * b; UNUSED(b); - if (op->op == GGML_OP_MUL_MAT) { - a = op->src[0]; - b = op->src[1]; - } else { - a = op->src[2]; - b = op->src[1]; - } - if (a->ne[3] != 1) { - return false; - } - if (ggml_is_quantized(a->type) && a->ne[2] != 1) { - return false; - } - return true; - } break; - default: - return false; - } + return ggml_metal_supports_op(op); UNUSED(backend); }