]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
model: LFM2-VL fixes (llama/17577)
authorTarek Dakhran <redacted>
Sun, 30 Nov 2025 20:57:31 +0000 (21:57 +0100)
committerGeorgi Gerganov <redacted>
Thu, 11 Dec 2025 13:32:50 +0000 (15:32 +0200)
* Adjust to pytorch

* Add antialiasing upscale

* Increase number of patches to 1024

* Handle default marker insertion for LFM2

* Switch to flag

* Reformat

* Cuda implementation of antialias kernel

* Change placement in ops.cpp

* consistent float literals

* Pad only for LFM2

* Address PR feedback

* Rollback default marker placement changes

* Fallback to CPU implementation for antialias implementation of upscale

include/ggml.h
src/ggml-cann/ggml-cann.cpp
src/ggml-cpu/ops.cpp
src/ggml-cuda/upscale.cu
src/ggml-metal/ggml-metal-device.m
src/ggml-opencl/ggml-opencl.cpp
src/ggml-sycl/ggml-sycl.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml.c
tests/test-backend-ops.cpp

index 4dbca868bc74a8cdafa6f1a6fd621a00cb51454a..48da68fe7e3eeaf8d4411847cce245b26a6f85c5 100644 (file)
@@ -2148,7 +2148,8 @@ extern "C" {
     };
 
     enum ggml_scale_flag {
-        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
+        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
+        GGML_SCALE_FLAG_ANTIALIAS     = (1 << 9),
     };
 
     // interpolate
index df28d67fb0b1af52a3a69204428fb0576cbb5ec3..cd1b5e5b944a6382d22e739ae8999abf9f46cd0a 100644 (file)
@@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
                 if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
                     return false;
                 }
+                if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
+                    return false;
+                }
                 return true;
             }
         case GGML_OP_POOL_2D:
index 2745fc54e1595c71f3d69cf82dd4d13291847c49..608e82af69f497e89eda6a13f6dbcf216ba18eaa 100644 (file)
@@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32(
                 }
             }
         }
+    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
+        // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
+        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
+        auto triangle_filter = [](float x) -> float {
+            return std::max(1.0f - fabsf(x), 0.0f);
+        };
+
+        // support and invscale, minimum 1 pixel for bilinear
+        const float support1  = std::max(1.0f, 1.0f / sf1);
+        const float invscale1 = 1.0f / support1;
+        const float support0  = std::max(1.0f, 1.0f / sf0);
+        const float invscale0 = 1.0f / support0;
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float) i1 + pixel_offset) / sf1;
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float) i0 + pixel_offset) / sf0;
+
+                        // the range of source pixels that contribute
+                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
+                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
+                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
+                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
+
+                        // bilinear filter with antialiasing
+                        float val = 0.0f;
+                        float total_weight = 0.0f;
+
+                        for (int64_t sy = y_min; sy < y_max; sy++) {
+                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
+
+                            for (int64_t sx = x_min; sx < x_max; sx++) {
+                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
+                                const float weight = weight_x * weight_y;
+
+                                if (weight <= 0.0f) {
+                                    continue;
+                                }
+
+                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
+                                val += pixel * weight;
+                                total_weight += weight;
+                            }
+                        }
+
+                        if (total_weight > 0.0f) {
+                            val /= total_weight;
+                        }
+
+                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                        *dst_ptr = val;
+                    }
+                }
+            }
+        }
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
         for (int64_t i3 = 0; i3 < ne3; i3++) {
             const int64_t i03 = i3 / sf3;
index 687c669304d8dbc9cca8badc75156d29b6f02e0c..6bdf3cd996bfc48555e24917778dd45eddf786e0 100644 (file)
@@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
     dst[index] = result;
 }
 
+// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
+// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
+static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne00_src, const int ne01_src,
+        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        const float pixel_offset) {
+    const int64_t index              = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+    if (index >= dst_total_elements) {
+        return;
+    }
+
+    const int i10_dst = index % ne10_dst;
+    const int i11_dst = (index / ne10_dst) % ne11_dst;
+    const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+    const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+    const int i02_src = (int)(i12_dst / sf2);
+    const int i03_src = (int)(i13_dst / sf3);
+
+    const float y = ((float)i11_dst + pixel_offset) / sf1;
+    const float x = ((float)i10_dst + pixel_offset) / sf0;
+
+    // support and invscale, minimum 1 pixel for bilinear
+    const float support1  = max(1.0f / sf1, 1.0f);
+    const float invscale1 = 1.0f / support1;
+    const float support0  = max(1.0f / sf0, 1.0f);
+    const float invscale0 = 1.0f / support0;
+
+    // the range of source pixels that contribute
+    const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
+    const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
+    const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
+    const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
+
+    // bilinear filter with antialiasing
+    float val = 0.0f;
+    float total_weight = 0.0f;
+
+    auto triangle_filter = [](float x) -> float {
+        return max(1.0f - fabsf(x), 0.0f);
+    };
+
+    for (int64_t sy = y_min; sy < y_max; sy++) {
+        const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
+
+        for (int64_t sx = x_min; sx < x_max; sx++) {
+            const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
+            const float weight = weight_x * weight_y;
+
+            if (weight <= 0.0f) {
+                continue;
+            }
+
+            const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
+            val += pixel * weight;
+            total_weight += weight;
+        }
+    }
+
+    if (total_weight > 0.0f) {
+        val /= total_weight;
+    }
+
+    dst[index] = val;
+}
+
 namespace bicubic_interpolation {
 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
 __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
@@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
         const int ne00_src, const int ne01_src,
         const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
         const float sf0, const float sf1, const float sf2, const float sf3,
-        const float pixel_offset, cudaStream_t stream) {
+        const float pixel_offset, bool antialias, cudaStream_t stream) {
     const int64_t dst_size   = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
     const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
 
-    upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+    if (antialias) {
+        upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+    } else {
+        upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+    }
 }
 
 static void upscale_f32_bicubic_cuda(const float * x, float * dst,
@@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (mode == GGML_SCALE_MODE_NEAREST) {
         upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
     } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
         upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                  src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
-                                 sf0, sf1, sf2, sf3, pixel_offset, stream);
+                                 sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
     } else if (mode == GGML_SCALE_MODE_BICUBIC) {
         upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                  src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
index 09b1b50311828c0b8b45d643f788f0694fc744b2..3aad16a3ff78be7c175741b90a752d0af38cc6f9 100644 (file)
@@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_POOL_1D:
             return false;
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
index e5302f4550e05abe0e6ed8aa7904233eafa69ab2..277a30d30ed7845f1b4d56e45618cf028e2eb90f 100644 (file)
@@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE: {
             ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
+            const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
-                   (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
+                   (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;
         }
         case GGML_OP_CONV_2D:
             return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
index 3f1bdfb9f1bababd12ebeea1efd73e7451d3c28e..e82b51206e2a16ccd16faba82ad2c28c3ec8ee7e 100644 (file)
@@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
index 66dd0bfabd2be2b16b947529d9593828a7821157..95966ce1d8ee7b5a73a115d607646edc1d44e76c 100644 (file)
@@ -14113,6 +14113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             return true;
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
         case GGML_OP_ACC:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_CONCAT:
index b99345a2e93b024a2296bd7c29d7a7affbaf0e8c..17cf4d84bb8f7026e2f28b4d5d0076c1743ecc4e 100644 (file)
@@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl(
         int64_t               ne3,
         uint32_t              mode) {
     GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
+    // TODO: implement antialias for modes other than bilinear
+    GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
index 87a61aa122445227f1a4e26666ab9881881d2372..9645d0b39097390f6b5ebb63fd87d1579bb9b206 100644 (file)
@@ -7660,7 +7660,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     //    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
     //}
 
-    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
         test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5,  7, 11}, {5, 7, 11, 13}, mode));