]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: add bilinear interpolation for upscale (#14563)
authorAman Gupta <redacted>
Tue, 8 Jul 2025 02:11:18 +0000 (10:11 +0800)
committerGitHub <redacted>
Tue, 8 Jul 2025 02:11:18 +0000 (10:11 +0800)
ggml/include/ggml.h
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/upscale.cu

index 949eac9a5a0b596f7dbac2db876443e0a5dc1d60..76b0c2a9887278e875e07011ec3fc583e3c3f24d 100644 (file)
@@ -495,7 +495,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ROLL,
index b6b7960f12146e1e4d15fb000e11f105604a3c4c..da1e8f8f4e44302fabb663e201d6a5150d173912 100644 (file)
@@ -3375,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
index 524e97957426615ce58965ad0758c3f5dc6afb52..ef48aa5f97bcd1a2d2052e94ca6f8c33120c8091 100644 (file)
@@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
     dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
 }
 
+static __global__ void upscale_f32_bilinear(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne00_src, const int ne01_src,
+        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        const float pixel_offset) {
+    const int64_t index              = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+    if (index >= dst_total_elements) {
+        return;
+    }
+
+    const int i10_dst = index % ne10_dst;
+    const int i11_dst = (index / ne10_dst) % ne11_dst;
+    const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+    const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+    const int i02_src = (int)(i12_dst / sf2);
+    const int i03_src = (int)(i13_dst / sf3);
+
+    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
+    int y0_src    = (int)floorf(y_src_f);
+    int y1_src    = y0_src + 1;
+
+    y0_src = max(0, min(y0_src, ne01_src - 1));
+    y1_src = max(0, min(y1_src, ne01_src - 1));
+
+    float dy = y_src_f - (float)y0_src;
+    dy       = max(0.0f, min(dy, 1.0f));
+
+    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
+    int x0_src    = (int)floorf(x_src_f);
+    int x1_src    = x0_src + 1;
+
+    x0_src = max(0, min(x0_src, ne00_src - 1));
+    x1_src = max(0, min(x1_src, ne00_src - 1));
+
+    float dx = x_src_f - (float)x0_src;
+    dx = max(0.0f, min(dx, 1.0f));
+
+    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+
+    const float val_a = *p_a;
+    const float val_b = *p_b;
+    const float val_c = *p_c;
+    const float val_d = *p_d;
+
+    float result = val_a * (1.0f - dx) * (1.0f - dy) +
+                   val_b * dx * (1.0f - dy) +
+                   val_c * (1.0f - dx) * dy +
+                   val_d * dx * dy;
+
+    dst[index] = result;
+}
+
 static void upscale_f32_cuda(const float * x, float * dst,
         const int nb00, const int nb01, const int nb02, const int nb03,
         const int ne10, const int ne11, const int ne12, const int ne13,
         const float sf0, const float sf1, const float sf2, const float sf3,
         cudaStream_t stream) {
-    int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+    const int64_t dst_size   = ne10 * ne11 * ne12 * ne13;
+    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
 
     upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
 }
 
+static void upscale_f32_bilinear_cuda(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne00_src, const int ne01_src,
+        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        const float pixel_offset, cudaStream_t stream) {
+    const int64_t dst_size   = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+}
+
 void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const float sf0 = (float)dst->ne[0]/src0->ne[0];
-    const float sf1 = (float)dst->ne[1]/src0->ne[1];
-    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const int mode_flags = dst->op_params[0];
+    const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
+
+    float sf0 = (float)dst->ne[0]/src0->ne[0];
+    float sf1 = (float)dst->ne[1]/src0->ne[1];
+    float sf2 = (float)dst->ne[2]/src0->ne[2];
     const float sf3 = (float)dst->ne[3]/src0->ne[3];
 
-    upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        float pixel_offset = 0.5f;
+        if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+            sf0          = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
+            sf1          = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
+            pixel_offset = 0.0f;
+        }
+        upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+                                 sf0, sf1, sf2, sf3, pixel_offset, stream);
+    }
 }