metal : check supported ops at runtime (#632)
author     Georgi Gerganov <redacted>
           Tue, 5 Dec 2023 13:17:48 +0000 (15:17 +0200)
committer  GitHub <redacted>
           Tue, 5 Dec 2023 13:17:48 +0000 (15:17 +0200)
* metal : check supported ops at runtime

* metal : remove TODOs
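
Illustration (not part of the commit): a minimal sketch of how the new runtime check can be exercised from outside the backend. It assumes the generic ggml_backend_supports_op() wrapper from ggml-backend.h is available in this tree; inside ggml_metal_graph_compute() the same ggml_metal_supports_op() helper is now enforced via the GGML_ASSERT added below. Here the GGML_OP_GET_ROWS node has ne[0] = 6, which fails the ne[0] % 4 == 0 rule, so the query is expected to return false.

    // illustrative sketch only -- assumes ggml_backend_supports_op() from
    // ggml-backend.h and the Metal backend init from ggml-metal.h
    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true, // only the graph structure is needed for the check
        };
        struct ggml_context * ctx = ggml_init(params);

        // GGML_OP_GET_ROWS is only supported when ne[0] is a multiple of 4
        struct ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 6, 10); // ne[0] = 6
        struct ggml_tensor * ids  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
        struct ggml_tensor * node = ggml_get_rows(ctx, rows, ids);

        ggml_backend_t metal = ggml_backend_metal_init();
        if (metal) {
            // expected to print 0: 6 % 4 != 0, so the Metal backend rejects this node
            printf("supported: %d\n", ggml_backend_supports_op(metal, node));
            ggml_backend_free(metal);
        }

        ggml_free(ctx);
        return 0;
    }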

src/ggml-metal.m

index f2267356cc0a0466137ddb42c8a246cc619cb6ca..cff9d5bc68ed5ca6aceaf69e63be99982da25a6b 100644 (file)
@@ -181,8 +181,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     }
 }
 
-
-
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
@@ -773,6 +771,70 @@ void ggml_metal_graph_find_concurrency(
     }
 }
 
+static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_CONCAT:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_NORM:
+        case GGML_OP_ALIBI:
+        case GGML_OP_ROPE:
+        case GGML_OP_IM2COL:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
+            {
+                return op->ne[0] % 4 == 0;
+            } break;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a;
+                struct ggml_tensor * b; UNUSED(b);
+                if (op->op == GGML_OP_MUL_MAT) {
+                    a = op->src[0];
+                    b = op->src[1];
+                } else {
+                    a = op->src[2];
+                    b = op->src[1];
+                }
+                if (a->ne[3] != 1) {
+                    return false;
+                }
+                if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
+                    return false;
+                }
+                return true;
+            } break;
+        default:
+            return false;
+    }
+}
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
@@ -843,6 +905,8 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
+                GGML_ASSERT(ggml_metal_supports_op(dst));
+
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
                 const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -1973,70 +2037,7 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_GELU:
-                    return true;
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_CONCAT:
-        case GGML_OP_ADD:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
-        case GGML_OP_ROPE:
-        case GGML_OP_IM2COL:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
-            return true;
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_GET_ROWS:
-            {
-                // TODO: also check during graph_compute
-                return op->ne[0] % 4 == 0;
-            } break;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            {
-                // TODO: also check during graph_compute
-                struct ggml_tensor * a;
-                struct ggml_tensor * b; UNUSED(b);
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != 1) {
-                    return false;
-                }
-                if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
-                    return false;
-                }
-                return true;
-            } break;
-        default:
-            return false;
-    }
+    return ggml_metal_supports_op(op);
 
     UNUSED(backend);
 }