From: Radoslav Gerganov
Date: Mon, 19 Aug 2024 07:09:33 +0000 (+0300)
Subject: yolo : add backend support (#924)
X-Git-Tag: upstream/0.0.1642~443
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=46e22f59eaf0aaa38a8e525fd89ba95e39ba7435;p=pkg%2Fggml%2Fsources%2Fggml

yolo : add backend support (#924)

* yolo : add backend support

* metal : add sub and sqrt kernels

---------

Co-authored-by: Georgi Gerganov
---

diff --git a/examples/yolo/README.md b/examples/yolo/README.md
index 0e69dc99..d2ced38c 100644
--- a/examples/yolo/README.md
+++ b/examples/yolo/README.md
@@ -17,11 +17,18 @@ $ ./convert-yolov3-tiny.py yolov3-tiny.weights
 yolov3-tiny.weights converted to yolov3-tiny.gguf
 ```
 
+Alternatively, you can download the converted model from [HuggingFace](https://huggingface.co/rgerganov/yolo-gguf/resolve/main/yolov3-tiny.gguf).
+
 Object detection:
 
 ```bash
 $ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
 $ ./yolov3-tiny -m yolov3-tiny.gguf -i dog.jpg
+load_model: using CUDA backend
+ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
+ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
+ggml_cuda_init: found 1 CUDA devices:
+  Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes
 Layer  0 output shape:  416 x 416 x   16 x   1
 Layer  1 output shape:  208 x 208 x   16 x   1
 Layer  2 output shape:  208 x 208 x   32 x   1
@@ -48,5 +55,5 @@ car: 52%
 truck: 56%
 car: 62%
 bicycle: 59%
-Detected objects saved in 'predictions.jpg' (time: 0.357000 sec.)
+Detected objects saved in 'predictions.jpg' (time: 0.057000 sec.)
 ```
\ No newline at end of file
diff --git a/examples/yolo/yolov3-tiny.cpp b/examples/yolo/yolov3-tiny.cpp
index ae7c2a00..369e9efb 100644
--- a/examples/yolo/yolov3-tiny.cpp
+++ b/examples/yolo/yolov3-tiny.cpp
@@ -1,4 +1,15 @@
 #include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 #include "yolo-image.h"
 
 #include <cmath>
@@ -29,6 +40,8 @@ struct yolo_model {
     int width  = 416;
     int height = 416;
     std::vector<conv2d_layer> conv2d_layers;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
     struct ggml_context * ctx;
 };

@@ -36,15 +49,20 @@ struct yolo_layer {
     int classes = 80;
     std::vector<int> mask;
     std::vector<float> anchors;
-    struct ggml_tensor * predictions;
-
-    yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * predictions)
-        : classes(classes), mask(mask), anchors(anchors), predictions(predictions)
-    { }
+    std::vector<float> predictions;
+    int w;
+    int h;
+
+    yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * prev_layer)
+        : classes(classes), mask(mask), anchors(anchors)
+    {
+        w = prev_layer->ne[0];
+        h = prev_layer->ne[1];
+        predictions.resize(ggml_nbytes(prev_layer)/sizeof(float));
+        ggml_backend_tensor_get(prev_layer, predictions.data(), 0, ggml_nbytes(prev_layer));
+    }
 
     int entry_index(int location, int entry) const {
-        int w = predictions->ne[0];
-        int h = predictions->ne[1];
         int n = location / (w*h);
         int loc = location % (w*h);
         return n*w*h*(4+classes+1) + entry*w*h + loc;
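The `yolo_layer` change above is the crux of the port: the layer's predictions may now live in GPU memory, so the constructor snapshots them into host memory with `ggml_backend_tensor_get()` instead of aliasing the tensor's data pointer. A minimal sketch of that pattern, assuming an F32 tensor (the helper name `tensor_to_host` is invented for illustration):

```cpp
#include "ggml.h"
#include "ggml-backend.h"

#include <vector>

// Copy a backend tensor (CPU, CUDA or Metal) into host memory.
// Assumes the tensor holds F32 data, as the YOLO output layers do.
static std::vector<float> tensor_to_host(const struct ggml_tensor * t) {
    std::vector<float> out(ggml_nbytes(t) / sizeof(float));
    ggml_backend_tensor_get(t, out.data(), 0, ggml_nbytes(t));
    return out;
}
```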
@@ -62,15 +80,60 @@ struct detection {
 };
 
 static bool load_model(const std::string & fname, yolo_model & model) {
-    struct gguf_init_params params = {
+    // initialize the backend
+#ifdef GGML_USE_CUDA
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    model.backend = ggml_backend_cuda_init(0); // init device 0
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    fprintf(stderr, "%s: using Metal backend\n", __func__);
+    model.backend = ggml_backend_metal_init();
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+    }
+#endif
+
+    // if there is no GPU backend, fall back to the CPU backend
+    if (!model.backend) {
+        model.backend = ggml_backend_cpu_init();
+    }
+    struct ggml_context * tmp_ctx = nullptr;
+    struct gguf_init_params gguf_params = {
         /*.no_alloc =*/ false,
-        /*.ctx      =*/ &model.ctx,
+        /*.ctx      =*/ &tmp_ctx,
     };
-    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
-    if (!ctx) {
+    gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), gguf_params);
+    if (!gguf_ctx) {
         fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
         return false;
     }
+
+    int num_tensors = gguf_get_n_tensors(gguf_ctx);
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    model.ctx = ggml_init(params);
+    for (int i = 0; i < num_tensors; i++) {
+        const char * name = gguf_get_tensor_name(gguf_ctx, i);
+        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, name);
+        struct ggml_tensor * dst = ggml_dup_tensor(model.ctx, src);
+        ggml_set_name(dst, name);
+    }
+    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);
+    // copy tensors from main memory to backend
+    for (struct ggml_tensor * cur = ggml_get_first_tensor(model.ctx); cur != NULL; cur = ggml_get_next_tensor(model.ctx, cur)) {
+        struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, ggml_get_name(cur));
+        size_t n_size = ggml_nbytes(src);
+        ggml_backend_tensor_set(cur, ggml_get_data(src), 0, n_size);
+    }
+    gguf_free(gguf_ctx);
+
     model.width  = 416;
     model.height = 416;
     model.conv2d_layers.resize(13);
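`load_model()` now selects a backend with a compile-time cascade and a runtime fallback: try CUDA, then Metal, then CPU. The same logic as a standalone helper, sketched only (the name `init_best_backend` is not part of the patch):

```cpp
#include "ggml-backend.h"
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif
#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

// Try GPU backends first, fall back to CPU. Mirrors load_model() above.
static ggml_backend_t init_best_backend(void) {
    ggml_backend_t backend = NULL;
#ifdef GGML_USE_CUDA
    backend = ggml_backend_cuda_init(0); // device 0
#endif
#ifdef GGML_USE_METAL
    if (!backend) backend = ggml_backend_metal_init();
#endif
    if (!backend) backend = ggml_backend_cpu_init();
    return backend;
}
```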
@@ -155,10 +218,10 @@ static void activate_array(float * x, const int n)
 
 static void apply_yolo(yolo_layer & layer)
 {
-    int w = layer.predictions->ne[0];
-    int h = layer.predictions->ne[1];
+    int w = layer.w;
+    int h = layer.h;
     int N = layer.mask.size();
-    float * data = ggml_get_data_f32(layer.predictions);
+    float * data = layer.predictions.data();
     for (int n = 0; n < N; n++) {
         int index = layer.entry_index(n*w*h, 0);
         activate_array(data + index, 2*w*h);
@@ -169,7 +232,7 @@ static void apply_yolo(yolo_layer & layer)
 
 static box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
 {
-    float * predictions = ggml_get_data_f32(layer.predictions);
+    const float * predictions = layer.predictions.data();
     box b;
     b.x = (i + predictions[index + 0*stride]) / lw;
     b.y = (j + predictions[index + 1*stride]) / lh;
@@ -197,10 +260,10 @@ static void correct_yolo_box(box & b, int im_w, int im_h, int net_w, int net_h)
 
 static void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)
 {
-    int w = layer.predictions->ne[0];
-    int h = layer.predictions->ne[1];
+    int w = layer.w;
+    int h = layer.h;
     int N = layer.mask.size();
-    float * predictions = ggml_get_data_f32(layer.predictions);
+    const float * predictions = layer.predictions.data();
     std::vector<detection> result;
     for (int i = 0; i < w*h; i++) {
         for (int n = 0; n < N; n++) {
@@ -353,88 +416,92 @@ static void print_shape(int layer, const ggml_tensor * t)
     printf("Layer %2d output shape:  %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);
 }
 
-void detect(yolo_image & img, const yolo_model & model, float thresh, const std::vector<std::string> & labels,
-            const std::vector<yolo_image> & alphabet)
-{
-    static size_t buf_size = 20000000 * sizeof(float) * 4;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-    std::vector<detection> detections;
+static struct ggml_cgraph * build_graph(struct ggml_context * ctx_cgraph, const yolo_model & model) {
+    struct ggml_cgraph * gf = ggml_new_graph(ctx_cgraph);
 
-    yolo_image sized = letterbox_image(img, model.width, model.height);
-    struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 3, 1);
-    std::memcpy(input->data, sized.data.data(), ggml_nbytes(input));
+    struct ggml_tensor * input = ggml_new_tensor_4d(ctx_cgraph, GGML_TYPE_F32, model.width, model.height, 3, 1);
     ggml_set_name(input, "input");
-
-    struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
+    struct ggml_tensor * result = apply_conv2d(ctx_cgraph, input, model.conv2d_layers[0]);
     print_shape(0, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(1, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[1]);
     print_shape(2, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(3, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[2]);
     print_shape(4, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(5, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[3]);
     print_shape(6, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(7, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[4]);
     struct ggml_tensor * layer_8 = result;
     print_shape(8, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
     print_shape(9, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[5]);
     print_shape(10, result);
-    result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
+    result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
     print_shape(11, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[6]);
     print_shape(12, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[7]);
     struct ggml_tensor * layer_13 = result;
     print_shape(13, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[8]);
     print_shape(14, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[9]);
     struct ggml_tensor * layer_15 = result;
+    ggml_set_output(layer_15);
+    ggml_set_name(layer_15, "layer_15");
+
     print_shape(15, result);
-    result = apply_conv2d(ctx0, layer_13, model.conv2d_layers[10]);
+    result = apply_conv2d(ctx_cgraph, layer_13, model.conv2d_layers[10]);
     print_shape(18, result);
-    result = ggml_upscale(ctx0, result, 2);
+    result = ggml_upscale(ctx_cgraph, result, 2);
     print_shape(19, result);
-    result = ggml_concat(ctx0, result, layer_8, 2);
+    result = ggml_concat(ctx_cgraph, result, layer_8, 2);
     print_shape(20, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[11]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[11]);
     print_shape(21, result);
-    result = apply_conv2d(ctx0, result, model.conv2d_layers[12]);
+    result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[12]);
     struct ggml_tensor * layer_22 = result;
+    ggml_set_output(layer_22);
+    ggml_set_name(layer_22, "layer_22");
     print_shape(22, result);
 
     ggml_build_forward_expand(gf, layer_15);
     ggml_build_forward_expand(gf, layer_22);
-    ggml_graph_compute_with_ctx(ctx0, gf, 1);
+    return gf;
+}
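`build_graph()` only records operations; no tensor data is allocated at this point (the graph context is created with `no_alloc = true` in `main()` below). The `ggml_set_output()` calls matter because `ggml_gallocr` may reuse the memory of intermediate tensors, so any tensor that is read back after compute must be flagged as an output, and named so it can be found again. A sketch of that idiom (the helper `mark_output` is invented here):

```cpp
#include "ggml.h"

// Flag a tensor so the graph allocator keeps its buffer alive, name it for
// later retrieval, and make sure it is part of the forward graph.
static struct ggml_tensor * mark_output(struct ggml_cgraph * gf, struct ggml_tensor * t, const char * name) {
    ggml_set_output(t);                // exclude t from buffer reuse during allocation planning
    ggml_set_name(t, name);            // fetch later with ggml_graph_get_tensor(gf, name)
    ggml_build_forward_expand(gf, t);
    return t;
}
```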
+
+void detect(yolo_image & img, struct ggml_cgraph * gf, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
+{
+    std::vector<detection> detections;
+    yolo_image sized = letterbox_image(img, model.width, model.height);
+    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
+    ggml_backend_tensor_set(input, sized.data.data(), 0, ggml_nbytes(input));
+
+    if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {
+        fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__);
+        return;
+    }
+    struct ggml_tensor * layer_15 = ggml_graph_get_tensor(gf, "layer_15");
     yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319}, layer_15};
     apply_yolo(yolo16);
     get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);
+    struct ggml_tensor * layer_22 = ggml_graph_get_tensor(gf, "layer_22");
     yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319}, layer_22};
     apply_yolo(yolo23);
     get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);
     do_nms_sort(detections, yolo23.classes, .45);
     draw_detections(img, detections, thresh, labels, alphabet);
-    ggml_free(ctx0);
 }
 
 struct yolo_params {
@@ -512,14 +579,31 @@ int main(int argc, char *argv[])
         fprintf(stderr, "%s: failed to load alphabet\n", __func__);
         return 1;
     }
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+    struct ggml_context * ctx_cgraph = ggml_init(params0);
+    struct ggml_cgraph * gf = build_graph(ctx_cgraph, model);
+
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
     const int64_t t_start_ms = ggml_time_ms();
-    detect(img, model, params.thresh, labels, alphabet);
+    detect(img, gf, model, params.thresh, labels, alphabet);
     const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;
     if (!save_image(img, params.fname_out.c_str(), 80)) {
         fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str());
         return 1;
     }
     printf("Detected objects saved in '%s' (time: %f sec.)\n", params.fname_out.c_str(), t_detect_ms / 1000.0f);
+
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
     ggml_free(model.ctx);
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_free(model.backend);
     return 0;
 }
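With the graph built once, per-image inference reduces to: allocate the graph tensors on the backend, upload the input, compute. A condensed sketch of that driver, assuming a `backend` and a `gf` produced as above (the function name `run_graph` is invented):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// One-time allocation plus one inference pass, mirroring main()/detect() above.
static bool run_graph(ggml_backend_t backend, struct ggml_cgraph * gf, const float * img_data, size_t img_size) {
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    if (!ggml_gallocr_alloc_graph(allocr, gf)) {
        return false;
    }
    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
    ggml_backend_tensor_set(input, img_data, 0, img_size); // host -> backend copy
    bool ok = ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS;
    ggml_gallocr_free(allocr);
    return ok;
}
```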
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index 8ff154f7..56c16a3c 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
diff --git a/src/ggml-cuda/binbcast.cu b/src/ggml-cuda/binbcast.cu
index 34bc67ac..e1390a04 100644
--- a/src/ggml-cuda/binbcast.cu
+++ b/src/ggml-cuda/binbcast.cu
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
diff --git a/src/ggml-cuda/binbcast.cuh b/src/ggml-cuda/binbcast.cuh
index 4f63d637..198c9ef6 100644
--- a/src/ggml-cuda/binbcast.cuh
+++ b/src/ggml-cuda/binbcast.cuh
@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
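`GGML_OP_SUB` costs almost nothing on the CUDA side: `op_sub` is a new scalar functor plugged into the existing templated `bin_bcast` kernel, so it inherits the same broadcast support as add, mul and div. In graph terms the second operand may be any shape that the first can repeat; a sketch of a graph node using that (the helper `sub_row_bias` and the freshly created `b` are illustrative only, with data to be set after allocation):

```cpp
#include "ggml.h"

// ggml_sub follows the same broadcast rules as ggml_add/ggml_mul/ggml_div:
// here b holds a single row that is repeated across all rows of a.
static struct ggml_tensor * sub_row_bias(struct ggml_context * ctx, struct ggml_tensor * a) {
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], 1);
    return ggml_sub(ctx, a, b); // dst[i0, i1] = a[i0, i1] - b[i0, 0]
}
```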
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
index f6bd6e34..7950c0dc 100644
--- a/src/ggml-metal.m
+++ b/src/ggml-metal.m
@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,     add,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,     sub,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,     mul,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,     div,     true);
@@ -667,6 +672,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
             return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             case GGML_OP_ADD:
+            case GGML_OP_SUB:
             case GGML_OP_MUL:
             case GGML_OP_DIV:
                 {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         nb = ne00 / 4;
                         switch (dst->op) {
                             case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                            case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                             case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                             case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                             default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     } else {
                         switch (dst->op) {
                             case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                            case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                             case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                             case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                             default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SQRT:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SIN:
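Note the gate in `ggml_metal_supports_op()`: like `SQR`, `SIN` and `COS`, the new `SQRT` kernel indexes its input as a flat array, so Metal only advertises it for contiguous tensors. A caller can check via the generic `ggml_backend_supports_op()` query and force contiguity with `ggml_cont()`; a sketch (the helper `safe_sqrt` is invented, and the extra unused node on retry is accepted for brevity):

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// The Metal SQRT kernel treats src0 as a flat float array, so it is only
// advertised for contiguous tensors; ggml_cont() inserts a compacting copy.
static struct ggml_tensor * safe_sqrt(struct ggml_context * ctx, ggml_backend_t backend, struct ggml_tensor * x) {
    struct ggml_tensor * y = ggml_sqrt(ctx, x);
    if (!ggml_backend_supports_op(backend, y)) {
        y = ggml_sqrt(ctx, ggml_cont(ctx, x)); // retry with a contiguous copy
    }
    return y;
}
```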
diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal
index 3e4b685b..17432085 100644
--- a/src/ggml-metal.metal
+++ b/src/ggml-metal.metal
@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
 kernel void kernel_sin(
         device const float * src0,
         device       float * dst,
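Taken together, the new kernels let a graph such as `sqrt(a - b)` run entirely on the GPU. A self-contained smoke test of the two new ops (all names are local to this sketch; swap in `ggml_backend_cuda_init(0)` or `ggml_backend_metal_init()` to exercise the new kernels):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <cstdio>

int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init(); // or a GPU backend

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // data is allocated by ggml_gallocr below
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_input(a);
    ggml_set_input(b);
    struct ggml_tensor * out = ggml_sqrt(ctx, ggml_sub(ctx, a, b)); // exercises SUB and SQRT
    ggml_set_output(out);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(allocr, gf);

    const float av[4] = {5, 10, 17, 26};
    const float bv[4] = {1,  1,  1,  1};
    ggml_backend_tensor_set(a, av, 0, sizeof(av));
    ggml_backend_tensor_set(b, bv, 0, sizeof(bv));

    if (ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS) {
        float res[4];
        ggml_backend_tensor_get(out, res, 0, sizeof(res));
        printf("%g %g %g %g\n", res[0], res[1], res[2], res[3]); // expect: 2 3 4 5
    }

    ggml_gallocr_free(allocr);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```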