]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : add FLOOR, CEIL, ROUND, TRUNC unary ops (llama/20930)
authornuri <redacted>
Tue, 24 Mar 2026 08:13:07 +0000 (17:13 +0900)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
Co-authored-by: nryoo <redacted>
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.metal

index 9162342ee98abeb390f25de840fbc4064ddc1c81..89539bd7615c6ded5fc74c552e13d02dcc8cda0b 100644 (file)
@@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
                 case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;
                 case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;
                 case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;
+                case GGML_UNARY_OP_FLOOR:       op_num = OP_UNARY_NUM_FLOOR;       break;
+                case GGML_UNARY_OP_CEIL:        op_num = OP_UNARY_NUM_CEIL;        break;
+                case GGML_UNARY_OP_ROUND:       op_num = OP_UNARY_NUM_ROUND;       break;
+                case GGML_UNARY_OP_TRUNC:       op_num = OP_UNARY_NUM_TRUNC;       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
index 2fbb274c5f92f32813fa78e193acac164c4c1add..cbef2fb4879ee82c9b2c0339cd828210782ed4d9 100644 (file)
@@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_EXPM1:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                     return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
                 default:
                     return false;
index ea471090cd8099439239bf5d9526cf9fe89c57a5..eb2253e029ad253dd628bb205935df345d09b249 100644 (file)
 #define OP_UNARY_NUM_EXP         114
 #define OP_UNARY_NUM_SOFTPLUS    115
 #define OP_UNARY_NUM_EXPM1       116
+#define OP_UNARY_NUM_FLOOR       117
+#define OP_UNARY_NUM_CEIL        118
+#define OP_UNARY_NUM_ROUND       119
+#define OP_UNARY_NUM_TRUNC       120
 
 #define OP_SUM_ROWS_NUM_SUM_ROWS 10
 #define OP_SUM_ROWS_NUM_MEAN     11
index 9286675189d15f514dbbb61a1c08a6569ec42930..2074211594ce53f100615347640aad62212d73f8 100644 (file)
@@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl(
             // TODO: precise implementation
             dst_ptr[i0] = (T) (exp(x) - 1);
         }
+
+        if (FC_OP == OP_UNARY_NUM_FLOOR) {
+            dst_ptr[i0] = (T) floor(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_CEIL) {
+            dst_ptr[i0] = (T) ceil(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ROUND) {
+            dst_ptr[i0] = (T) round(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_TRUNC) {
+            dst_ptr[i0] = (T) trunc(x);
+        }
     }
 
 #undef FC_OP