git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
yolo : add backend support (ggml/924)
author    Radoslav Gerganov <redacted>
          Mon, 19 Aug 2024 07:09:33 +0000 (10:09 +0300)
committer Georgi Gerganov <redacted>
          Wed, 21 Aug 2024 08:07:13 +0000 (11:07 +0300)
* yolo : add backend support

* metal : add sub and sqrt kernels

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/binbcast.cuh
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal
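
For context: the yolo example builds graphs that use GGML_OP_SUB and GGML_OP_SQRT, which the CUDA and Metal backends did not fully cover before this commit, forcing a CPU fallback for those nodes. A minimal sketch of a graph fragment that can now stay on the GPU (the helper name is illustrative, not from this commit; ggml_sub/ggml_sqr/ggml_sqrt are the public ggml API):

#include "ggml.h"

// Elementwise |a - b| built from the ops this commit wires into the
// CUDA and Metal backends (SUB and SQRT; SQR was already supported).
static struct ggml_tensor * abs_diff(struct ggml_context * ctx,
                                     struct ggml_tensor  * a,
                                     struct ggml_tensor  * b) {
    struct ggml_tensor * d = ggml_sub(ctx, a, b);  // GGML_OP_SUB
    return ggml_sqrt(ctx, ggml_sqr(ctx, d));       // GGML_OP_SQR -> GGML_OP_SQRT
}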

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 8ff154f729ed612f8bc6241fe09348a3936dfddd..56c16a3c461ca90d7e49f14dca66d7470c898b7e 100644
@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
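
The supports_op table extended here is what the graph scheduler consults when assigning nodes; an op missing from this switch is routed to the CPU backend instead. The same check is reachable from user code through the public API; a hedged sketch (ggml_backend_supports_op is from ggml-backend.h, the wrapper name is illustrative):

#include "ggml-backend.h"

// Returns true when 'node' (e.g. a GGML_OP_SUB tensor after this commit)
// can be placed on 'backend' rather than falling back to the CPU.
static bool can_offload(ggml_backend_t backend, const struct ggml_tensor * node) {
    return ggml_backend_supports_op(backend, node);
}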
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 34bc67acdd890c077ae15de763435bab09ff0f2c..e1390a0414559fca9bc0a7ae05fd394de3c65615 100644
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
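
ggml_cuda_op_sub above reuses the bin_bcast_cuda template, which handles broadcasting and non-contiguous layouts through per-dimension byte strides; the op itself is only the one-line op_sub device function. Stripped of that machinery, the computation reduces to a plain elementwise kernel; a simplified sketch for the contiguous, same-shape f32 case (not the actual template):

// Simplified sketch: one thread per element. The real bin_bcast_cuda
// template instead walks broadcast/stride arithmetic across all four
// tensor dimensions.
__global__ void sub_f32_flat(const float * a, const float * b, float * c, const int64_t n) {
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] - b[i];
    }
}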
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
index 4f63d6372eb50e717f47f7fc4844d90719119f82..198c9ef6fd8ea73c3e9e85f5ef0e60676365ca91 100644
@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index f6bd6e3407e54f093218e4eedcb51a8585842860..7950c0dccb7f301c8f923bbaa16b1826102f7ba9 100644
@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
@@ -667,6 +672,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
             return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_SUB:
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             nb = ne00 / 4;
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         } else {
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SQRT:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SIN:
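
In the Metal backend, GGML_OP_SUB shares the ADD/MUL/DIV dispatch: the generic kernel_sub covers arbitrary broadcast, while kernel_sub_row is a float4 fast path taken (roughly) when src1 is a single contiguous row whose length is a multiple of 4, hence the nb = ne00 / 4 above. The row path's broadcast rule is just a modulo over the row; a host-side reference of its semantics (illustrative names, contiguous f32 data):

// What kernel_sub_row computes, expressed scalar-wise on the host:
// src1 is one row of ne10 floats, repeated across all of src0.
static void sub_row_ref(const float * src0, const float * src1, float * dst,
                        const int64_t n, const int64_t ne10) {
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = src0[i] - src1[i % ne10];  // same modulo as src1[tpig % nb]
    }
}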
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 3e4b685bb51c4aa43cf79744bbbddd82752138a1..17432085c03ad7e6b9e93554665b277aa9c9c133 100644
@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
 kernel void kernel_sin(
         device const float * src0,
         device       float * dst,