]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
yolo : add backend support (#924)
authorRadoslav Gerganov <redacted>
Mon, 19 Aug 2024 07:09:33 +0000 (10:09 +0300)
committerGitHub <redacted>
Mon, 19 Aug 2024 07:09:33 +0000 (10:09 +0300)
* yolo : add backend support

* metal : add sub and sqrt kernels

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/yolo/README.md
examples/yolo/yolov3-tiny.cpp
src/ggml-cuda.cu
src/ggml-cuda/binbcast.cu
src/ggml-cuda/binbcast.cuh
src/ggml-metal.m
src/ggml-metal.metal

index 0e69dc996c56f43a05079f355e4aeceb88ff306c..d2ced38cac574ba271bcd165167e9ba4cd718521 100644 (file)
@@ -17,11 +17,18 @@ $ ./convert-yolov3-tiny.py yolov3-tiny.weights
 yolov3-tiny.weights converted to yolov3-tiny.gguf
 ```
 
+Alternatively, you can download the converted model from [HuggingFace](https://huggingface.co/rgerganov/yolo-gguf/resolve/main/yolov3-tiny.gguf)
+
 Object detection:
 
 ```bash
 $ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
 $ ./yolov3-tiny -m yolov3-tiny.gguf -i dog.jpg
+load_model: using CUDA backend
+ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
+ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
+ggml_cuda_init: found 1 CUDA devices:
+  Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes
 Layer  0 output shape:  416 x 416 x   16 x   1
 Layer  1 output shape:  208 x 208 x   16 x   1
 Layer  2 output shape:  208 x 208 x   32 x   1
@@ -48,5 +55,5 @@ car: 52%
 truck: 56%
 car: 62%
 bicycle: 59%
-Detected objects saved in 'predictions.jpg' (time: 0.357000 sec.)
+Detected objects saved in 'predictions.jpg' (time: 0.057000 sec.)
 ```
\ No newline at end of file
index ae7c2a00537399c18cc8d6ba09589b59dad1fce7..369e9efbcd116f7d1867a1b84b32911452cf854a 100644 (file)
@@ -1,4 +1,15 @@
 #include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 #include "yolo-image.h"
 
 #include <cmath>
@@ -29,6 +40,8 @@ struct yolo_model {
     int width = 416;
     int height = 416;
     std::vector<conv2d_layer> conv2d_layers;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
     struct ggml_context * ctx;
 };
 
@@ -36,15 +49,20 @@ struct yolo_layer {
     int classes = 80;
     std::vector<int> mask;
     std::vector<float> anchors;
-    struct ggml_tensor * predictions;
-
-    yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * predictions)
-        : classes(classes), mask(mask), anchors(anchors), predictions(predictions)
-    { }
+    std::vector<float> predictions;
+    int w;
+    int h;
+
+    yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * prev_layer)
+        : classes(classes), mask(mask), anchors(anchors)
+    {
+        w = prev_layer->ne[0];
+        h = prev_layer->ne[1];
+        predictions.resize(ggml_nbytes(prev_layer)/sizeof(float));
+        ggml_backend_tensor_get(prev_layer, predictions.data(), 0, ggml_nbytes(prev_layer));
+    }
 
     int entry_index(int location, int entry) const {
-        int w = predictions->ne[0];
-        int h = predictions->ne[1];
         int n = location / (w*h);
         int loc = location % (w*h);
         return n*w*h*(4+classes+1) + entry*w*h + loc;
@@ -62,15 +80,60 @@ struct detection {
 };
 
 static bool load_model(const std::string & fname, yolo_model & model) {
-    struct gguf_init_params params = {
+    // initialize the backend
+#ifdef GGML_USE_CUDA
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    model.backend = ggml_backend_cuda_init(0); // init device 0
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    fprintf(stderr, "%s: using Metal backend\n", __func__);
+    model.backend = ggml_backend_metal_init();
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+    }
+#endif
+
+    // if there aren't GPU Backends fallback to CPU backend
+    if (!model.backend) {
+        model.backend = ggml_backend_cpu_init();
+    }
+    struct ggml_context * tmp_ctx = nullptr;
+    struct gguf_init_params gguf_params = {
         /*.no_alloc   =*/ false,
-        /*.ctx        =*/ &model.ctx,
+        /*.ctx        =*/ &tmp_ctx,
     };
-    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
-    if (!ctx) {
+    gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), gguf_params);
+    if (!gguf_ctx) {
         fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
         return false;
     }
+
+    int num_tensors = gguf_get_n_tensors(gguf_ctx);
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+    model.ctx = ggml_init(params);
+    for (int i = 0; i < num_tensors; i++) {
+        const char * name = gguf_get_tensor_name(gguf_ctx, i);
+        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, name);
+        struct ggml_tensor * dst = ggml_dup_tensor(model.ctx, src);
+        ggml_set_name(dst, name);
+    }
+    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);
+    // copy tensors from main memory to backend
+    for (struct ggml_tensor * cur = ggml_get_first_tensor(model.ctx); cur != NULL; cur = ggml_get_next_tensor(model.ctx, cur)) {
+        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, ggml_get_name(cur));
+        size_t n_size = ggml_nbytes(src);
+        ggml_backend_tensor_set(cur, ggml_get_data(src), 0, n_size);
+    }
+    gguf_free(gguf_ctx);
+
     model.width  = 416;
     model.height = 416;
     model.conv2d_layers.resize(13);
@@ -155,10 +218,10 @@ static void activate_array(float * x, const int n)
 
 static void apply_yolo(yolo_layer & layer)
 {
-    int w = layer.predictions->ne[0];
-    int h = layer.predictions->ne[1];
+    int w = layer.w;
+    int h = layer.h;
     int N = layer.mask.size();
-    float * data = ggml_get_data_f32(layer.predictions);
+    float * data = layer.predictions.data();
     for (int n = 0; n < N; n++) {
         int index = layer.entry_index(n*w*h, 0);
         activate_array(data + index, 2*w*h);
@@ -169,7 +232,7 @@ static void apply_yolo(yolo_layer & layer)
 
 static box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
 {
-    float * predictions = ggml_get_data_f32(layer.predictions);
+    const float * predictions = layer.predictions.data();
     box b;
     b.x = (i + predictions[index + 0*stride]) / lw;
     b.y = (j + predictions[index + 1*stride]) / lh;
@@ -197,10 +260,10 @@ static void correct_yolo_box(box & b, int im_w, int im_h, int net_w, int net_h)
 
 static void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)
 {
-    int w = layer.predictions->ne[0];
-    int h = layer.predictions->ne[1];
+    int w = layer.w;
+    int h = layer.h;
     int N = layer.mask.size();
-    float * predictions = ggml_get_data_f32(layer.predictions);
+    const float * predictions = layer.predictions.data();
     std::vector<detection> result;
     for (int i = 0; i < w*h; i++) {
         for (int n = 0; n < N; n++) {
@@ -353,88 +416,92 @@ static void print_shape(int layer, const ggml_tensor * t)
     printf("Layer %2d output shape:  %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);
 }
 
-void detect(yolo_image & img, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
-{
-    static size_t buf_size = 20000000 * sizeof(float) * 4;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-    std::vector<detection> detections;
+static struct ggml_cgraph * build_graph(struct ggml_context * ctx_cgraph, const yolo_model & model) {
+    struct ggml_cgraph * gf = ggml_new_graph(ctx_cgraph);
 
-    yolo_image sized = letterbox_image(img, model.width, model.height);
-    struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 3, 1);
-    std::memcpy(input->data, sized.data.data(), ggml_nbytes(input));
+    struct ggml_tensor * input = ggml_new_tensor_4d(ctx_cgraph, GGML_TYPE_F32, model.width, model.height, 3, 1);
     ggml_set_name(input, "input");
-
-    struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
+    struct ggml_tensor * result = apply_conv2d(ctx_cgraph, input, model.conv2d_layers[0]);
     print_shape(0, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(1, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[1]);
     print_shape(2, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(3, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[2]);
     print_shape(4, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(5, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[3]);
     print_shape(6, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(7, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[4]);
     struct ggml_tensor * layer_8 = result;
     print_shape(8, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(9, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[5]);
     print_shape(10, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
     print_shape(11, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[6]);
     print_shape(12, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[7]);
     struct ggml_tensor * layer_13 = result;
     print_shape(13, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[8]);
     print_shape(14, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[9]);
     struct ggml_tensor * layer_15 = result;
+    ggml_set_output(layer_15);
+    ggml_set_name(layer_15, "layer_15");
+
     print_shape(15, result);
-    result = apply_conv2d(ctx0, layer_13, model.conv2d_layers[10]);
+    result = apply_conv2d(ctx_cgraph, layer_13, model.conv2d_layers[10]);
     print_shape(18, result);
-    result = ggml_upscale(ctx0, result, 2);
+    result = ggml_upscale(ctx_cgraph, result, 2);
     print_shape(19, result);
-    result = ggml_concat(ctx0, result, layer_8, 2);
+    result = ggml_concat(ctx_cgraph, result, layer_8, 2);
     print_shape(20, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[11]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[11]);
     print_shape(21, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[12]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[12]);
     struct ggml_tensor * layer_22 = result;
+    ggml_set_output(layer_22);
+    ggml_set_name(layer_22, "layer_22");
     print_shape(22, result);
 
     ggml_build_forward_expand(gf, layer_15);
     ggml_build_forward_expand(gf, layer_22);
-    ggml_graph_compute_with_ctx(ctx0, gf, 1);
+    return gf;
+}
+
+void detect(yolo_image & img, struct ggml_cgraph * gf, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
+{
+    std::vector<detection> detections;
+    yolo_image sized = letterbox_image(img, model.width, model.height);
+    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
+    ggml_backend_tensor_set(input, sized.data.data(), 0, ggml_nbytes(input));
+
+    if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {
+        fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__);
+        return;
+    }
 
+    struct ggml_tensor * layer_15 = ggml_graph_get_tensor(gf, "layer_15");
     yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_15};
     apply_yolo(yolo16);
     get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);
 
+    struct ggml_tensor * layer_22 = ggml_graph_get_tensor(gf, "layer_22");
     yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_22};
     apply_yolo(yolo23);
     get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);
 
     do_nms_sort(detections, yolo23.classes, .45);
     draw_detections(img, detections, thresh, labels, alphabet);
-    ggml_free(ctx0);
 }
 
 struct yolo_params {
@@ -512,14 +579,31 @@ int main(int argc, char *argv[])
         fprintf(stderr, "%s: failed to load alphabet\n", __func__);
         return 1;
     }
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+    struct ggml_context * ctx_cgraph = ggml_init(params0);
+    struct ggml_cgraph * gf = build_graph(ctx_cgraph, model);
+
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
     const int64_t t_start_ms = ggml_time_ms();
-    detect(img, model, params.thresh, labels, alphabet);
+    detect(img, gf, model, params.thresh, labels, alphabet);
     const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;
     if (!save_image(img, params.fname_out.c_str(), 80)) {
         fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str());
         return 1;
     }
     printf("Detected objects saved in '%s' (time: %f sec.)\n", params.fname_out.c_str(), t_detect_ms / 1000.0f);
+
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
     ggml_free(model.ctx);
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_free(model.backend);
     return 0;
 }
index 8ff154f729ed612f8bc6241fe09348a3936dfddd..56c16a3c461ca90d7e49f14dca66d7470c898b7e 100644 (file)
@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
index 34bc67acdd890c077ae15de763435bab09ff0f2c..e1390a0414559fca9bc0a7ae05fd394de3c65615 100644 (file)
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
index 4f63d6372eb50e717f47f7fc4844d90719119f82..198c9ef6fd8ea73c3e9e85f5ef0e60676365ca91 100644 (file)
@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index f6bd6e3407e54f093218e4eedcb51a8585842860..7950c0dccb7f301c8f923bbaa16b1826102f7ba9 100644 (file)
@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
@@ -667,6 +672,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
             return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_SUB:
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             nb = ne00 / 4;
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         } else {
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SQRT:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SIN:
index 3e4b685bb51c4aa43cf79744bbbddd82752138a1..17432085c03ad7e6b9e93554665b277aa9c9c133 100644 (file)
@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
 kernel void kernel_sin(
         device const float * src0,
         device       float * dst,