]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add upscale (#20284)
authorGeorgi Gerganov <redacted>
Mon, 9 Mar 2026 14:45:11 +0000 (16:45 +0200)
committerGitHub <redacted>
Mon, 9 Mar 2026 14:45:11 +0000 (16:45 +0200)
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 06f3d80459069cc65a0ce26da732da3418bd9a31..169c63dd7a457bb4ccdbd421669ac9c49434d8f8 100644 (file)
@@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
-    snprintf(name, 256, "%s", base);
+    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
+
+    const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
+
+    if (mode == GGML_SCALE_MODE_BILINEAR) {
+        snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
+    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
+        snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
+    } else {
+        snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
+    }
+    snprintf(name, 256, "%s_aa=%d", base, antialias);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     return res;
index 4cce414abfef44c1e75c48e4ddf0cf591301af5c..23bd2b2ab72e620a7192561cb567ef5079cbdfdc 100644 (file)
@@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    op->type == GGML_TYPE_F32 &&
                    (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_1D:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
index 383e0d6e93b5728e752dcae59898ecaf268a39eb..bf51055e36751dcfacb33290736495237be87d54 100644 (file)
@@ -83,6 +83,7 @@
 #define FC_UNARY                       1200
 #define FC_BIN                         1300
 #define FC_SUM_ROWS                    1400
+#define FC_UPSCALE                     1500
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -890,6 +891,7 @@ typedef struct {
     float    sf1;
     float    sf2;
     float    sf3;
+    float    poffs;
 } ggml_metal_kargs_upscale;
 
 typedef struct {
index b3390352ffcfbf3bd335efb607d5092ef2d3534a..524e1116629827b5481aa2eb0d0df0d91d9aaf07 100644 (file)
@@ -3729,32 +3729,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    const float sf0 = (float)ne0/op->src[0]->ne[0];
-    const float sf1 = (float)ne1/op->src[0]->ne[1];
-    const float sf2 = (float)ne2/op->src[0]->ne[2];
-    const float sf3 = (float)ne3/op->src[0]->ne[3];
+    float sf0 = (float)ne0/op->src[0]->ne[0];
+    float sf1 = (float)ne1/op->src[0]->ne[1];
+    float sf2 = (float)ne2/op->src[0]->ne[2];
+    float sf3 = (float)ne3/op->src[0]->ne[3];
+
+    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
+
+    float poffs = 0.5f;
+
+    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        poffs = 0.0f;
+        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+    }
 
     ggml_metal_kargs_upscale args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
-        /*.nb00 =*/ nb00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.nb03 =*/ nb03,
-        /*.ne0 =*/ ne0,
-        /*.ne1 =*/ ne1,
-        /*.ne2 =*/ ne2,
-        /*.ne3 =*/ ne3,
-        /*.nb0 =*/ nb0,
-        /*.nb1 =*/ nb1,
-        /*.nb2 =*/ nb2,
-        /*.nb3 =*/ nb3,
-        /*.sf0 =*/ sf0,
-        /*.sf1 =*/ sf1,
-        /*.sf2 =*/ sf2,
-        /*.sf3 =*/ sf3
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.sf0   =*/ sf0,
+        /*.sf1   =*/ sf1,
+        /*.sf2   =*/ sf2,
+        /*.sf3   =*/ sf3,
+        /*.poffs =*/ poffs,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
index a58e641ad86b834622366f1d329765d924e50394..5cfd69dd86614789693e55883a6135c438c2e782 100644 (file)
@@ -4530,7 +4530,9 @@ kernel void kernel_conv_transpose_2d<half>(
     uint3   tpitg[[thread_position_in_threadgroup]],
     uint3     ntg[[threads_per_threadgroup]]);
 
-kernel void kernel_upscale_f32(
+constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
+
+kernel void kernel_upscale_nearest_f32(
     constant ggml_metal_kargs_upscale & args,
     device  const char * src0,
     device        char * dst,
@@ -4556,6 +4558,156 @@ kernel void kernel_upscale_f32(
     }
 }
 
+static inline float bilinear_tri(float x) {
+    return MAX(0.0f, 1.0f - fabs(x));
+}
+
+kernel void kernel_upscale_bilinear_f32(
+    constant ggml_metal_kargs_upscale & args,
+    device  const char * src0,
+    device        char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 / args.sf3;
+    const int64_t i02 = i2 / args.sf2;
+
+    const float   f01  = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
+    const int64_t i01  = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
+    const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
+    const float   fd1  = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
+
+    src0 += i03*args.nb03 + i02*args.nb02;
+
+    device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
+
+    if (FC_upscale_aa) {
+        const float support0  = MAX(1.0f, 1.0f / args.sf0);
+        const float invscale0 = 1.0f / support0;
+        const float support1  = MAX(1.0f, 1.0f / args.sf1);
+        const float invscale1 = 1.0f / support1;
+
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+
+            int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
+            int64_t x_max = MIN(args.ne00,  (int64_t)ceil (f00 + support0 + args.poffs));
+
+            int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
+            int64_t y_max = MIN(args.ne01,  (int64_t)ceil (f01 + support1 + args.poffs));
+
+            float sum = 0.0f;
+            float wsum = 0.0f;
+
+            for (int64_t sy = y_min; sy < y_max; ++sy) {
+                const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
+                for (int64_t sx = x_min; sx < x_max; ++sx) {
+                    const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
+                    const float w  = wx * wy;
+                    const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
+                    sum  += (*src_ptr) * w;
+                    wsum += w;
+                }
+            }
+
+            const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
+            dst_ptr[i0] = v;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const float   f00  = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+            const int64_t i00  = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
+            const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
+            const float   fd0  = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
+
+            device const float * src00 = (device const float *)(src0 + i01*args.nb01  + i00*args.nb00);
+            device const float * src10 = (device const float *)(src0 + i01*args.nb01  + i00p*args.nb00);
+            device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
+            device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
+
+            const float v =
+                (*src00) * (1.0f - fd0) * (1.0f - fd1) +
+                (*src10) * fd0          * (1.0f - fd1) +
+                (*src01) * (1.0f - fd0) * fd1 +
+                (*src11) * fd0          * fd1;
+
+            dst_ptr[i0] = v;
+        }
+    }
+}
+
+static inline float bicubic_weight1(float x) {
+    const float a = -0.75f;
+    return ((a + 2) * x - (a + 3)) * x * x + 1;
+}
+
+static inline float bicubic_weight2(float x) {
+    const float a = -0.75f;
+    return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
+}
+
+kernel void kernel_upscale_bicubic_f32(
+    constant ggml_metal_kargs_upscale & args,
+    device  const char * src0,
+    device        char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 / args.sf3;
+    const int64_t i02 = i2 / args.sf2;
+
+    const float   f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
+    const int64_t i01 = (int64_t)floor(f01);
+    const float   fd1 = f01 - (float)i01;
+
+    const float w_y0 = bicubic_weight2(fd1 + 1.0f);
+    const float w_y1 = bicubic_weight1(fd1);
+    const float w_y2 = bicubic_weight1(1.0f - fd1);
+    const float w_y3 = bicubic_weight2(2.0f - fd1);
+
+    const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
+
+    device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const float   f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+        const int64_t i00 = (int64_t)floor(f00);
+        const float   fd0 = f00 - (float)i00;
+
+        const float w_x0 = bicubic_weight2(fd0 + 1.0f);
+        const float w_x1 = bicubic_weight1(fd0);
+        const float w_x2 = bicubic_weight1(1.0f - fd0);
+        const float w_x3 = bicubic_weight2(2.0f - fd0);
+
+        float sum = 0.0f;
+
+        for (int dy = -1; dy <= 2; ++dy) {
+            const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
+            const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
+
+            for (int dx = -1; dx <= 2; ++dx) {
+                const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
+                const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
+
+                const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
+                sum += (*src_ptr) * wx * wy;
+            }
+        }
+
+        dst_ptr[i0] = sum;
+    }
+}
+
 kernel void kernel_pad_f32(
     constant ggml_metal_kargs_pad & args,
     device  const char * src0,