]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add `ggml_upscale_ext` (#814)
authorJohn Balis <redacted>
Wed, 15 May 2024 08:52:33 +0000 (03:52 -0500)
committerGitHub <redacted>
Wed, 15 May 2024 08:52:33 +0000 (11:52 +0300)
* initial commit with CPU implementation of upscale to shape and test, cuda implementation next

* experimental commit to see if dst shape is correct

* test version

* test

* removed unnecessary params

* refactor

* fixed tests

* ggml : metal impl + cleanup + sycl dev warnings

* patched ggml_upscale cuda op to handle non-contiguous tensors, added test for non-contiguous behavior

* metal : fix upsacle op to support nb00 + style

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml/ggml.h
src/ggml-cuda/upscale.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-sycl.cpp
src/ggml.c
tests/test-backend-ops.cpp

index 25f4f73a8d96b403a126532f62ce6c2294e4a93c..5e121604a80e6177f4db1912d630629092ef8fa3 100644 (file)
@@ -1674,12 +1674,24 @@ extern "C" {
             float                 p1);
 
     // nearest interpolate
+    // multiplies ne0 and ne1 by scale factor
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_upscale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // nearest interpolate
+    // nearest interpolate to specified dimensions
+    // used in tortoise.cpp
+    GGML_API struct ggml_tensor * ggml_upscale_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   ne0,
+            int                   ne1,
+            int                   ne2,
+            int                   ne3);
+
     // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
     GGML_API struct ggml_tensor * ggml_pad(
             struct ggml_context * ctx,
index 2f62fed488835ca97df4af0bd09521e8f36c0842..cf513c3ade7c4f0f7e75bf8a1f35f099978a0e06 100644 (file)
@@ -1,35 +1,36 @@
 #include "upscale.cuh"
 
-static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
-    // blockIdx.z: idx of ne02*ne03
-    // blockIdx.y: idx of ne01*scale_factor, aka ne1
-    // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
-    // ne00xne01: ne00 * ne01
-    int ne0 = ne00 * scale_factor;
-    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
-    if (nidx >= ne0) {
+static __global__ void upscale_f32(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (index >= ne10 * ne11 * ne12 * ne13) {
         return;
     }
-    // operation
-    int i00 = nidx / scale_factor;
-    int i01 = blockIdx.y / scale_factor;
-    int offset_src =
-        i00 +
-        i01 * ne00 +
-        blockIdx.z * ne00xne01;
-    int offset_dst =
-        nidx +
-        blockIdx.y * ne0 +
-        blockIdx.z * ne0 * gridDim.y;
-    dst[offset_dst] = x[offset_src];
+
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+    int i00 = i10 / sf0;
+    int i01 = i11 / sf1;
+    int i02 = i12 / sf2;
+    int i03 = i13 / sf3;
+
+    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
 
-static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
-                             const int scale_factor, cudaStream_t stream) {
-    int ne0 = (ne00 * scale_factor);
-    int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
-    upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
+static void upscale_f32_cuda(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        cudaStream_t stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
 }
 
 void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -39,10 +40,12 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int scale_factor = dst->op_params[0];
+    const float sf0 = (float)dst->ne[0]/src0->ne[0];
+    const float sf1 = (float)dst->ne[1]/src0->ne[1];
+    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const float sf3 = (float)dst->ne[3]/src0->ne[3];
 
-    upscale_f32_cuda(src0_d, dst_d, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, stream);
+    upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
 }
index 390a1cd789081d36f55c358a9c57fa0ac6d998e7..b0b16dbf7716096a9c0f7cab1bee092500525289 100644 (file)
@@ -2353,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                        const int sf = dst->op_params[0];
+                        const float sf0 = (float)ne0/src0->ne[0];
+                        const float sf1 = (float)ne1/src0->ne[1];
+                        const float sf2 = (float)ne2/src0->ne[2];
+                        const float sf3 = (float)ne3/src0->ne[3];
 
                         const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
@@ -2376,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
                         [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
                         [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                        [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
+                        [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
+                        [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
+                        [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
+                        [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
 
                         const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
index 57fdf564e17f6b5444607f01b50385cc998b0dda..386e9195fcffa5ecde003ab2d896955f7ca68264 100644 (file)
@@ -1852,7 +1852,10 @@ kernel void kernel_upscale_f32(
     constant  uint64_t & nb1,
     constant  uint64_t & nb2,
     constant  uint64_t & nb3,
-    constant   int32_t & sf,
+    constant     float & sf0,
+    constant     float & sf1,
+    constant     float & sf2,
+    constant     float & sf3,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -1861,15 +1864,17 @@ kernel void kernel_upscale_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i1 = tgpig.x;
 
-    const int64_t i03 = i3;
-    const int64_t i02 = i2;
-    const int64_t i01 = i1/sf;
-
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    const int64_t i03 = i3/sf3;
+    const int64_t i02 = i2/sf2;
+    const int64_t i01 = i1/sf1;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = src0_ptr[i0/sf];
+        const int64_t i00 = i0/sf0;
+
+        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+
+        dst_ptr[0] = src0_ptr[0];
     }
 }
 
index 724070eb91080549184f9a1756a77c37fbfd3833..b15efb70421dadd53693be8e5ea7fd187525b346 100644 (file)
@@ -13987,6 +13987,10 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
+#pragma message("TODO: generalize upscale operator")
+#pragma message("      https://github.com/ggerganov/ggml/pull/814")
+    GGML_ASSERT(false && "TODO: generalize upscale operator);
+
     const int scale_factor = dst->op_params[0];
 
     upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
index 03b609dddce1065392ade7a64c0eb88594997309..f09cc30609bfd7d189a7fea50438853934b327ac 100644 (file)
@@ -6293,7 +6293,10 @@ struct ggml_tensor * ggml_pool_2d(
 static struct ggml_tensor * ggml_upscale_impl(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
-    int scale_factor) {
+    int ne0,
+    int ne1,
+    int ne2,
+    int ne3) {
     bool is_node = false;
 
     if (a->grad) {
@@ -6301,19 +6304,45 @@ static struct ggml_tensor * ggml_upscale_impl(
         is_node = true;
     }
 
+    GGML_ASSERT(a->ne[0] <= ne0);
+    GGML_ASSERT(a->ne[1] <= ne1);
+    GGML_ASSERT(a->ne[2] <= ne2);
+    GGML_ASSERT(a->ne[3] <= ne3);
+
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
-            a->ne[0] * scale_factor,
-            a->ne[1] * scale_factor,
-            a->ne[2], a->ne[3]);
+            ne0,
+            ne1,
+            ne2,
+            ne3
+            );
 
     result->op = GGML_OP_UPSCALE;
-    result->op_params[0] = scale_factor;
+
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
 
     return result;
 }
 
+struct ggml_tensor * ggml_upscale(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int scale_factor) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+}
+
+struct ggml_tensor * ggml_upscale_ext(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int ne0,
+    int ne1,
+    int ne2,
+    int ne3) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+}
+
+// ggml_pad
+
 struct ggml_tensor * ggml_pad(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
@@ -6338,12 +6367,7 @@ struct ggml_tensor * ggml_pad(
     return result;
 }
 
-struct ggml_tensor * ggml_upscale(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int scale_factor) {
-    return ggml_upscale_impl(ctx, a, scale_factor);
-}
+// ggml_arange
 
 struct ggml_tensor * ggml_arange(
     struct ggml_context * ctx,
@@ -6365,6 +6389,8 @@ struct ggml_tensor * ggml_arange(
     return result;
 }
 
+// ggml_timestep_embedding
+
 struct ggml_tensor * ggml_timestep_embedding(
             struct ggml_context * ctx,
             struct ggml_tensor  * timesteps,
@@ -14820,25 +14846,28 @@ static void ggml_compute_forward_upscale_f32(
         return;
     }
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const int scale_factor = dst->op_params[0];
+    const float sf0 = (float)ne0/src0->ne[0];
+    const float sf1 = (float)ne1/src0->ne[1];
+    const float sf2 = (float)ne2/src0->ne[2];
+    const float sf3 = (float)ne3/src0->ne[3];
 
     // TODO: optimize
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        const int64_t i03 = i3;
+        const int64_t i03 = i3 / sf3;
         for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
-            const int64_t i02 = i2;
+            const int64_t i02 = i2 / sf2;
             for (int64_t i1 = 0; i1 < ne1; i1++) {
-                const int64_t i01 = i1 / scale_factor;
+                const int64_t i01 = i1 / sf1;
                 for (int64_t i0 = 0; i0 < ne0; i0++) {
-                    const int64_t i00 = i0 / scale_factor;
+                    const int64_t i00 = i0 / sf0;
 
                     const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                           float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
@@ -14868,6 +14897,7 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+
 // ggml_compute_forward_pad
 
 static void ggml_compute_forward_pad_f32(
index f080f7e22cd35d42e56d953955d1a5dc914e10df..85ef21c2a5c19cfffa6a8e12c0b5246e3a2f27a2 100644 (file)
@@ -1329,23 +1329,47 @@ struct test_upscale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
+    const bool transpose;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, scale_factor);
+        return VARS_TO_STR4(type, ne, scale_factor, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2)
-        : type(type), ne(ne), scale_factor(scale_factor) {}
+            int32_t scale_factor = 2, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        if (transpose) a = ggml_transpose(ctx, a);
         ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
         return out;
     }
 };
 
+// GGML_OP_UPSCALE (ext)
+struct test_upscale_ext : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_tgt;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, ne_tgt);
+    }
+
+    test_upscale_ext(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne     = {2, 5,  7, 11},
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
+        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        return out;
+    }
+};
+
 // GGML_OP_GROUP_NORM
 struct test_group_norm : public test_case {
     const ggml_type type;
@@ -2169,6 +2193,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_upscale());
+    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
+    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm());
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());