#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
#include "yolo-image.h"
#include <cmath>
int width = 416;
int height = 416;
std::vector<conv2d_layer> conv2d_layers;
+ ggml_backend_t backend = NULL;
+ ggml_backend_buffer_t buffer;
struct ggml_context * ctx;
};
int classes = 80;
std::vector<int> mask;
std::vector<float> anchors;
- struct ggml_tensor * predictions;
-
- yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * predictions)
- : classes(classes), mask(mask), anchors(anchors), predictions(predictions)
- { }
+ std::vector<float> predictions;
+ int w;
+ int h;
+
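+    // the constructor copies the layer output from the backend tensor into host
+    // memory, since all of the YOLO post-processing below runs on the CPU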
+ yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * prev_layer)
+ : classes(classes), mask(mask), anchors(anchors)
+ {
+ w = prev_layer->ne[0];
+ h = prev_layer->ne[1];
+ predictions.resize(ggml_nbytes(prev_layer)/sizeof(float));
+ ggml_backend_tensor_get(prev_layer, predictions.data(), 0, ggml_nbytes(prev_layer));
+ }
int entry_index(int location, int entry) const {
- int w = predictions->ne[0];
- int h = predictions->ne[1];
int n = location / (w*h);
int loc = location % (w*h);
return n*w*h*(4+classes+1) + entry*w*h + loc;
};
static bool load_model(const std::string & fname, yolo_model & model) {
- struct gguf_init_params params = {
+ // initialize the backend
+#ifdef GGML_USE_CUDA
+ fprintf(stderr, "%s: using CUDA backend\n", __func__);
+ model.backend = ggml_backend_cuda_init(0); // init device 0
+ if (!model.backend) {
+ fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+ }
+#endif
+
+#ifdef GGML_USE_METAL
+ fprintf(stderr, "%s: using Metal backend\n", __func__);
+ model.backend = ggml_backend_metal_init();
+ if (!model.backend) {
+ fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+ }
+#endif
+
+    // if no GPU backend was initialized, fall back to the CPU backend
+ if (!model.backend) {
+ model.backend = ggml_backend_cpu_init();
+ }
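+
+    // load the GGUF file into a temporary CPU-side context; its tensors are then
+    // duplicated into model.ctx and their data uploaded to the backend buffer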
+ struct ggml_context * tmp_ctx = nullptr;
+ struct gguf_init_params gguf_params = {
/*.no_alloc =*/ false,
- /*.ctx =*/ &model.ctx,
+ /*.ctx =*/ &tmp_ctx,
};
- gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
- if (!ctx) {
+ gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), gguf_params);
+ if (!gguf_ctx) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return false;
}
+
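+    // create a no_alloc context for the model: it holds only tensor metadata,
+    // the tensor data itself lives in the backend buffer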
+ int num_tensors = gguf_get_n_tensors(gguf_ctx);
+ struct ggml_init_params params {
+ /*.mem_size =*/ ggml_tensor_overhead() * num_tensors,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ model.ctx = ggml_init(params);
+ for (int i = 0; i < num_tensors; i++) {
+ const char * name = gguf_get_tensor_name(gguf_ctx, i);
+ struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, name);
+ struct ggml_tensor * dst = ggml_dup_tensor(model.ctx, src);
+ ggml_set_name(dst, name);
+ }
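+    // allocate a single backend buffer large enough for all of the model tensors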
+ model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);
+    // copy the tensor data from the temporary CPU context into the backend buffer
+ for (struct ggml_tensor * cur = ggml_get_first_tensor(model.ctx); cur != NULL; cur = ggml_get_next_tensor(model.ctx, cur)) {
+ struct ggml_tensor * src = ggml_get_tensor(tmp_ctx, ggml_get_name(cur));
+ size_t n_size = ggml_nbytes(src);
+ ggml_backend_tensor_set(cur, ggml_get_data(src), 0, n_size);
+ }
+    gguf_free(gguf_ctx);
+    ggml_free(tmp_ctx);
+
model.width = 416;
model.height = 416;
model.conv2d_layers.resize(13);
static void apply_yolo(yolo_layer & layer)
{
- int w = layer.predictions->ne[0];
- int h = layer.predictions->ne[1];
+ int w = layer.w;
+ int h = layer.h;
int N = layer.mask.size();
- float * data = ggml_get_data_f32(layer.predictions);
+ float * data = layer.predictions.data();
for (int n = 0; n < N; n++) {
int index = layer.entry_index(n*w*h, 0);
activate_array(data + index, 2*w*h);
static box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
{
- float * predictions = ggml_get_data_f32(layer.predictions);
+ const float * predictions = layer.predictions.data();
box b;
b.x = (i + predictions[index + 0*stride]) / lw;
b.y = (j + predictions[index + 1*stride]) / lh;
static void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)
{
- int w = layer.predictions->ne[0];
- int h = layer.predictions->ne[1];
+ int w = layer.w;
+ int h = layer.h;
int N = layer.mask.size();
- float * predictions = ggml_get_data_f32(layer.predictions);
+ const float * predictions = layer.predictions.data();
std::vector<detection> result;
for (int i = 0; i < w*h; i++) {
for (int n = 0; n < N; n++) {
printf("Layer %2d output shape: %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);
}
-void detect(yolo_image & img, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
-{
- static size_t buf_size = 20000000 * sizeof(float) * 4;
- static void * buf = malloc(buf_size);
-
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_size,
- /*.mem_buffer =*/ buf,
- /*.no_alloc =*/ false,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- std::vector<detection> detections;
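+// build the YOLO compute graph; this only defines the tensors and ops,
+// their memory is allocated later with ggml_gallocr_alloc_graph()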
+static struct ggml_cgraph * build_graph(struct ggml_context * ctx_cgraph, const yolo_model & model) {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx_cgraph);
- yolo_image sized = letterbox_image(img, model.width, model.height);
- struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 3, 1);
- std::memcpy(input->data, sized.data.data(), ggml_nbytes(input));
+ struct ggml_tensor * input = ggml_new_tensor_4d(ctx_cgraph, GGML_TYPE_F32, model.width, model.height, 3, 1);
ggml_set_name(input, "input");
-
- struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
+ struct ggml_tensor * result = apply_conv2d(ctx_cgraph, input, model.conv2d_layers[0]);
print_shape(0, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(1, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[1]);
print_shape(2, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(3, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[2]);
print_shape(4, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(5, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[3]);
print_shape(6, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(7, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[4]);
struct ggml_tensor * layer_8 = result;
print_shape(8, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(9, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[5]);
print_shape(10, result);
- result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
+ result = ggml_pool_2d(ctx_cgraph, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
print_shape(11, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[6]);
print_shape(12, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[7]);
struct ggml_tensor * layer_13 = result;
print_shape(13, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[8]);
print_shape(14, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[9]);
struct ggml_tensor * layer_15 = result;
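+    // mark as a graph output and name it so detect() can look it up with ggml_graph_get_tensor()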
+ ggml_set_output(layer_15);
+ ggml_set_name(layer_15, "layer_15");
+
print_shape(15, result);
- result = apply_conv2d(ctx0, layer_13, model.conv2d_layers[10]);
+ result = apply_conv2d(ctx_cgraph, layer_13, model.conv2d_layers[10]);
print_shape(18, result);
- result = ggml_upscale(ctx0, result, 2);
+ result = ggml_upscale(ctx_cgraph, result, 2);
print_shape(19, result);
- result = ggml_concat(ctx0, result, layer_8, 2);
+ result = ggml_concat(ctx_cgraph, result, layer_8, 2);
print_shape(20, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[11]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[11]);
print_shape(21, result);
- result = apply_conv2d(ctx0, result, model.conv2d_layers[12]);
+ result = apply_conv2d(ctx_cgraph, result, model.conv2d_layers[12]);
struct ggml_tensor * layer_22 = result;
+ ggml_set_output(layer_22);
+ ggml_set_name(layer_22, "layer_22");
print_shape(22, result);
ggml_build_forward_expand(gf, layer_15);
ggml_build_forward_expand(gf, layer_22);
- ggml_graph_compute_with_ctx(ctx0, gf, 1);
+ return gf;
+}
+
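+// run detection: upload the letterboxed image into the "input" tensor, compute the
+// graph on the backend, then decode the YOLO output layers on the CPU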
+void detect(yolo_image & img, struct ggml_cgraph * gf, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
+{
+ std::vector<detection> detections;
+ yolo_image sized = letterbox_image(img, model.width, model.height);
+ struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
+ ggml_backend_tensor_set(input, sized.data.data(), 0, ggml_nbytes(input));
+
+ if (ggml_backend_graph_compute(model.backend, gf) != GGML_STATUS_SUCCESS) {
+ fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__);
+ return;
+ }
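+    // the output tensors were named in build_graph(), so they can be fetched by name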
+ struct ggml_tensor * layer_15 = ggml_graph_get_tensor(gf, "layer_15");
    yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319}, layer_15};
apply_yolo(yolo16);
get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);
+ struct ggml_tensor * layer_22 = ggml_graph_get_tensor(gf, "layer_22");
    yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319}, layer_22};
apply_yolo(yolo23);
get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);
do_nms_sort(detections, yolo23.classes, .45);
draw_detections(img, detections, thresh, labels, alphabet);
- ggml_free(ctx0);
}
struct yolo_params {
fprintf(stderr, "%s: failed to load alphabet\n", __func__);
return 1;
}
+
+ struct ggml_init_params params0 = {
+ /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+ };
+ struct ggml_context * ctx_cgraph = ggml_init(params0);
+ struct ggml_cgraph * gf = build_graph(ctx_cgraph, model);
+
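+    // allocate the graph tensors in a buffer of the backend's default buffer type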
+ ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+ ggml_gallocr_alloc_graph(allocr, gf);
+
const int64_t t_start_ms = ggml_time_ms();
- detect(img, model, params.thresh, labels, alphabet);
+ detect(img, gf, model, params.thresh, labels, alphabet);
const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;
if (!save_image(img, params.fname_out.c_str(), 80)) {
fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str());
return 1;
}
printf("Detected objects saved in '%s' (time: %f sec.)\n", params.fname_out.c_str(), t_detect_ms / 1000.0f);
+
+ ggml_free(ctx_cgraph);
+ ggml_gallocr_free(allocr);
ggml_free(model.ctx);
+ ggml_backend_buffer_free(model.buffer);
+ ggml_backend_free(model.backend);
return 0;
}
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_SUB,
+ GGML_METAL_KERNEL_TYPE_SUB_ROW,
GGML_METAL_KERNEL_TYPE_MUL,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
+ case GGML_OP_SUB:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CLAMP:
return true;
case GGML_OP_SQR:
+ case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
return ggml_is_contiguous(op->src[0]);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ADD:
+ case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
{
nb = ne00 / 4;
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
default: GGML_ABORT("fatal error");
} else {
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
default: GGML_ABORT("fatal error");
const int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SQRT:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SIN:
GGML_SORT_ORDER_DESC,
};
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
kernel void kernel_add(
}
}
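+// same as kernel_add, but computes src0 - src1 (src1 is broadcast across all dims)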
+kernel void kernel_sub(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int64_t & offs,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
kernel void kernel_mul(
device const char * src0,
device const char * src1,
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
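+// optimized sub for the case where src1 is a single row broadcast over src0 (see kernel_add_row)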
+kernel void kernel_sub_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
kernel void kernel_mul_row(
device const float4 * src0,
device const float4 * src1,
dst[tpig] = src0[tpig] * src0[tpig];
}
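+// element-wise square root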
+kernel void kernel_sqrt(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sqrt(src0[tpig]);
+}
+
kernel void kernel_sin(
device const float * src0,
device float * dst,