- [X] Example of ChatGLM inference [li-plus/chatglm.cpp](https://github.com/li-plus/chatglm.cpp)
- [X] Example of Stable Diffusion inference [leejet/stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
- [X] Example of Qwen inference [QwenLM/qwen.cpp](https://github.com/QwenLM/qwen.cpp)
+- [X] Example of YOLO inference [examples/yolo](https://github.com/ggerganov/ggml/tree/master/examples/yolo)
## Whisper inference (example)
gg_printf '```\n'
}
+# yolo
+
+function gg_run_yolo {
+ cd ${SRC}
+
+ gg_wget models-mnt/yolo/ https://pjreddie.com/media/files/yolov3-tiny.weights
+ gg_wget models-mnt/yolo/ https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
+
+ cd build-ci-release
+ cp -r ../examples/yolo/data .
+
+ set -e
+
+ path_models="../models-mnt/yolo/"
+
+ python3 ../examples/yolo/convert-yolov3-tiny.py ${path_models}/yolov3-tiny.weights
+
+ (time ./bin/yolov3-tiny -m yolov3-tiny.gguf -i ${path_models}/dog.jpg ) 2>&1 | tee -a $OUT/${ci}-main.log
+
+ grep -q "dog: 57%" $OUT/${ci}-main.log
+ grep -q "car: 52%" $OUT/${ci}-main.log
+ grep -q "truck: 56%" $OUT/${ci}-main.log
+ grep -q "bicycle: 59%" $OUT/${ci}-main.log
+
+ set +e
+}
+
+function gg_sum_yolo {
+ gg_printf '### %s\n\n' "${ci}"
+
+ gg_printf 'Run YOLO\n'
+ gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+ gg_printf '```\n'
+ gg_printf '%s\n' "$(cat $OUT/${ci}-main.log)"
+ gg_printf '```\n'
+}
+
# mpt
function gg_run_mpt {
test $ret -eq 0 && gg_run mnist
test $ret -eq 0 && gg_run whisper
test $ret -eq 0 && gg_run sam
+test $ret -eq 0 && gg_run yolo
if [ -z $GG_BUILD_LOW_PERF ]; then
if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 16 ]; then
add_subdirectory(mpt)
add_subdirectory(starcoder)
add_subdirectory(sam)
+add_subdirectory(yolo)
--- /dev/null
+#
+# yolov3-tiny
+
+set(TEST_TARGET yolov3-tiny)
+add_executable(${TEST_TARGET} yolov3-tiny.cpp yolo-image.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
--- /dev/null
+This example shows how to implement YOLO object detection with ggml using pretrained model.
+
+# YOLOv3-tiny
+
+Download the model weights:
+
+```bash
+$ wget https://pjreddie.com/media/files/yolov3-tiny.weights
+$ sha1sum yolov3-tiny.weights
+40f3c11883bef62fd850213bc14266632ed4414f yolov3-tiny.weights
+```
+
+Convert the weights to GGUF format:
+
+```bash
+$ ./convert-yolov3-tiny.py yolov3-tiny.weights
+yolov3-tiny.weights converted to yolov3-tiny.gguf
+```
+
+Object detection:
+
+```bash
+$ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
+$ ./yolov3-tiny -m yolov3-tiny.gguf -i dog.jpg
+Layer 0 output shape: 416 x 416 x 16 x 1
+Layer 1 output shape: 208 x 208 x 16 x 1
+Layer 2 output shape: 208 x 208 x 32 x 1
+Layer 3 output shape: 104 x 104 x 32 x 1
+Layer 4 output shape: 104 x 104 x 64 x 1
+Layer 5 output shape: 52 x 52 x 64 x 1
+Layer 6 output shape: 52 x 52 x 128 x 1
+Layer 7 output shape: 26 x 26 x 128 x 1
+Layer 8 output shape: 26 x 26 x 256 x 1
+Layer 9 output shape: 13 x 13 x 256 x 1
+Layer 10 output shape: 13 x 13 x 512 x 1
+Layer 11 output shape: 13 x 13 x 512 x 1
+Layer 12 output shape: 13 x 13 x 1024 x 1
+Layer 13 output shape: 13 x 13 x 256 x 1
+Layer 14 output shape: 13 x 13 x 512 x 1
+Layer 15 output shape: 13 x 13 x 255 x 1
+Layer 18 output shape: 13 x 13 x 128 x 1
+Layer 19 output shape: 26 x 26 x 128 x 1
+Layer 20 output shape: 26 x 26 x 384 x 1
+Layer 21 output shape: 26 x 26 x 256 x 1
+Layer 22 output shape: 26 x 26 x 255 x 1
+dog: 57%
+car: 52%
+truck: 56%
+car: 62%
+bicycle: 59%
+Detected objects saved in 'predictions.jpg' (time: 0.357000 sec.)
+```
\ No newline at end of file
--- /dev/null
+#!/usr/bin/env python3
+import sys
+import gguf
+import numpy as np
+
+def save_conv2d_layer(f, gguf_writer, prefix, inp_c, filters, size, batch_normalize=True):
+ biases = np.fromfile(f, dtype=np.float32, count=filters)
+ gguf_writer.add_tensor(prefix + "_biases", biases, raw_shape=(1, filters, 1, 1))
+
+ if batch_normalize:
+ scales = np.fromfile(f, dtype=np.float32, count=filters)
+ gguf_writer.add_tensor(prefix + "_scales", scales, raw_shape=(1, filters, 1, 1))
+ rolling_mean = np.fromfile(f, dtype=np.float32, count=filters)
+ gguf_writer.add_tensor(prefix + "_rolling_mean", rolling_mean, raw_shape=(1, filters, 1, 1))
+ rolling_variance = np.fromfile(f, dtype=np.float32, count=filters)
+ gguf_writer.add_tensor(prefix + "_rolling_variance", rolling_variance, raw_shape=(1, filters, 1, 1))
+
+ weights_count = filters * inp_c * size * size
+ l0_weights = np.fromfile(f, dtype=np.float32, count=weights_count)
+ ## ggml doesn't support f32 convolution yet, use f16 instead
+ l0_weights = l0_weights.astype(np.float16)
+ gguf_writer.add_tensor(prefix + "_weights", l0_weights, raw_shape=(filters, inp_c, size, size))
+
+
+if __name__ == '__main__':
+ if len(sys.argv) != 2:
+ print("Usage: %s <yolov3-tiny.weights>" % sys.argv[0])
+ sys.exit(1)
+ outfile = 'yolov3-tiny.gguf'
+ gguf_writer = gguf.GGUFWriter(outfile, 'yolov3-tiny')
+
+ f = open(sys.argv[1], 'rb')
+ f.read(20) # skip header
+ save_conv2d_layer(f, gguf_writer, "l0", 3, 16, 3)
+ save_conv2d_layer(f, gguf_writer, "l1", 16, 32, 3)
+ save_conv2d_layer(f, gguf_writer, "l2", 32, 64, 3)
+ save_conv2d_layer(f, gguf_writer, "l3", 64, 128, 3)
+ save_conv2d_layer(f, gguf_writer, "l4", 128, 256, 3)
+ save_conv2d_layer(f, gguf_writer, "l5", 256, 512, 3)
+ save_conv2d_layer(f, gguf_writer, "l6", 512, 1024, 3)
+ save_conv2d_layer(f, gguf_writer, "l7", 1024, 256, 1)
+ save_conv2d_layer(f, gguf_writer, "l8", 256, 512, 3)
+ save_conv2d_layer(f, gguf_writer, "l9", 512, 255, 1, batch_normalize=False)
+ save_conv2d_layer(f, gguf_writer, "l10", 256, 128, 1)
+ save_conv2d_layer(f, gguf_writer, "l11", 384, 256, 3)
+ save_conv2d_layer(f, gguf_writer, "l12", 256, 255, 1, batch_normalize=False)
+ f.close()
+
+ gguf_writer.write_header_to_file()
+ gguf_writer.write_kv_data_to_file()
+ gguf_writer.write_tensors_to_file()
+ gguf_writer.close()
+ print("{} converted to {}".format(sys.argv[1], outfile))
--- /dev/null
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
--- /dev/null
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+#include "stb_image_write.h"
+
+#include "yolo-image.h"
+
+static void draw_box(yolo_image & a, int x1, int y1, int x2, int y2, float r, float g, float b)
+{
+ if (x1 < 0) x1 = 0;
+ if (x1 >= a.w) x1 = a.w-1;
+ if (x2 < 0) x2 = 0;
+ if (x2 >= a.w) x2 = a.w-1;
+
+ if (y1 < 0) y1 = 0;
+ if (y1 >= a.h) y1 = a.h-1;
+ if (y2 < 0) y2 = 0;
+ if (y2 >= a.h) y2 = a.h-1;
+
+ for (int i = x1; i <= x2; ++i){
+ a.data[i + y1*a.w + 0*a.w*a.h] = r;
+ a.data[i + y2*a.w + 0*a.w*a.h] = r;
+
+ a.data[i + y1*a.w + 1*a.w*a.h] = g;
+ a.data[i + y2*a.w + 1*a.w*a.h] = g;
+
+ a.data[i + y1*a.w + 2*a.w*a.h] = b;
+ a.data[i + y2*a.w + 2*a.w*a.h] = b;
+ }
+ for (int i = y1; i <= y2; ++i){
+ a.data[x1 + i*a.w + 0*a.w*a.h] = r;
+ a.data[x2 + i*a.w + 0*a.w*a.h] = r;
+
+ a.data[x1 + i*a.w + 1*a.w*a.h] = g;
+ a.data[x2 + i*a.w + 1*a.w*a.h] = g;
+
+ a.data[x1 + i*a.w + 2*a.w*a.h] = b;
+ a.data[x2 + i*a.w + 2*a.w*a.h] = b;
+ }
+}
+
+void draw_box_width(yolo_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b)
+{
+ for (int i = 0; i < w; ++i) {
+ draw_box(a, x1+i, y1+i, x2-i, y2-i, r, g, b);
+ }
+}
+
+bool save_image(const yolo_image & im, const char *name, int quality)
+{
+ uint8_t *data = (uint8_t*)calloc(im.w*im.h*im.c, sizeof(uint8_t));
+ for (int k = 0; k < im.c; ++k) {
+ for (int i = 0; i < im.w*im.h; ++i) {
+ data[i*im.c+k] = (uint8_t) (255*im.data[i + k*im.w*im.h]);
+ }
+ }
+ int success = stbi_write_jpg(name, im.w, im.h, im.c, data, quality);
+ free(data);
+ if (!success) {
+ fprintf(stderr, "Failed to write image %s\n", name);
+ return false;
+ }
+ return true;
+}
+
+bool load_image(const char *fname, yolo_image & img)
+{
+ int w, h, c;
+ uint8_t * data = stbi_load(fname, &w, &h, &c, 3);
+ if (!data) {
+ return false;
+ }
+ c = 3;
+ img.w = w;
+ img.h = h;
+ img.c = c;
+ img.data.resize(w*h*c);
+ for (int k = 0; k < c; ++k){
+ for (int j = 0; j < h; ++j){
+ for (int i = 0; i < w; ++i){
+ int dst_index = i + w*j + w*h*k;
+ int src_index = k + c*i + c*w*j;
+ img.data[dst_index] = (float)data[src_index]/255.;
+ }
+ }
+ }
+ stbi_image_free(data);
+ return true;
+}
+
+static yolo_image resize_image(const yolo_image & im, int w, int h)
+{
+ yolo_image resized(w, h, im.c);
+ yolo_image part(w, im.h, im.c);
+ float w_scale = (float)(im.w - 1) / (w - 1);
+ float h_scale = (float)(im.h - 1) / (h - 1);
+ for (int k = 0; k < im.c; ++k){
+ for (int r = 0; r < im.h; ++r) {
+ for (int c = 0; c < w; ++c) {
+ float val = 0;
+ if (c == w-1 || im.w == 1){
+ val = im.get_pixel(im.w-1, r, k);
+ } else {
+ float sx = c*w_scale;
+ int ix = (int) sx;
+ float dx = sx - ix;
+ val = (1 - dx) * im.get_pixel(ix, r, k) + dx * im.get_pixel(ix+1, r, k);
+ }
+ part.set_pixel(c, r, k, val);
+ }
+ }
+ }
+ for (int k = 0; k < im.c; ++k){
+ for (int r = 0; r < h; ++r){
+ float sy = r*h_scale;
+ int iy = (int) sy;
+ float dy = sy - iy;
+ for (int c = 0; c < w; ++c){
+ float val = (1-dy) * part.get_pixel(c, iy, k);
+ resized.set_pixel(c, r, k, val);
+ }
+ if (r == h-1 || im.h == 1) continue;
+ for (int c = 0; c < w; ++c){
+ float val = dy * part.get_pixel(c, iy+1, k);
+ resized.add_pixel(c, r, k, val);
+ }
+ }
+ }
+ return resized;
+}
+
+static void embed_image(const yolo_image & source, yolo_image & dest, int dx, int dy)
+{
+ for (int k = 0; k < source.c; ++k) {
+ for (int y = 0; y < source.h; ++y) {
+ for (int x = 0; x < source.w; ++x) {
+ float val = source.get_pixel(x, y, k);
+ dest.set_pixel(dx+x, dy+y, k, val);
+ }
+ }
+ }
+}
+
+yolo_image letterbox_image(const yolo_image & im, int w, int h)
+{
+ int new_w = im.w;
+ int new_h = im.h;
+ if (((float)w/im.w) < ((float)h/im.h)) {
+ new_w = w;
+ new_h = (im.h * w)/im.w;
+ } else {
+ new_h = h;
+ new_w = (im.w * h)/im.h;
+ }
+ yolo_image resized = resize_image(im, new_w, new_h);
+ yolo_image boxed(w, h, im.c);
+ boxed.fill(0.5);
+ embed_image(resized, boxed, (w-new_w)/2, (h-new_h)/2);
+ return boxed;
+}
+
+static yolo_image tile_images(const yolo_image & a, const yolo_image & b, int dx)
+{
+ if (a.w == 0) {
+ return b;
+ }
+ yolo_image c(a.w + b.w + dx, (a.h > b.h) ? a.h : b.h, a.c);
+ c.fill(1.0f);
+ embed_image(a, c, 0, 0);
+ embed_image(b, c, a.w + dx, 0);
+ return c;
+}
+
+static yolo_image border_image(const yolo_image & a, int border)
+{
+ yolo_image b(a.w + 2*border, a.h + 2*border, a.c);
+ b.fill(1.0f);
+ embed_image(a, b, border, border);
+ return b;
+}
+
+yolo_image get_label(const std::vector<yolo_image> & alphabet, const std::string & label, int size)
+{
+ size = size/10;
+ size = std::min(size, 7);
+ yolo_image result(0,0,0);
+ for (int i = 0; i < (int)label.size(); ++i) {
+ int ch = label[i];
+ yolo_image img = alphabet[size*128 + ch];
+ result = tile_images(result, img, -size - 1 + (size+1)/2);
+ }
+ return border_image(result, (int)(result.h*.25));
+}
+
+void draw_label(yolo_image & im, int row, int col, const yolo_image & label, const float * rgb)
+{
+ int w = label.w;
+ int h = label.h;
+ if (row - h >= 0) {
+ row = row - h;
+ }
+ for (int j = 0; j < h && j + row < im.h; j++) {
+ for (int i = 0; i < w && i + col < im.w; i++) {
+ for (int k = 0; k < label.c; k++) {
+ float val = label.get_pixel(i, j, k);
+ im.set_pixel(i + col, j + row, k, rgb[k] * val);
+ }
+ }
+ }
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <string>
+#include <vector>
+#include <cassert>
+
+struct yolo_image {
+ int w, h, c;
+ std::vector<float> data;
+
+ yolo_image() : w(0), h(0), c(0) {}
+ yolo_image(int w, int h, int c) : w(w), h(h), c(c), data(w*h*c) {}
+
+ float get_pixel(int x, int y, int c) const {
+ assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);
+ return data[c*w*h + y*w + x];
+ }
+
+ void set_pixel(int x, int y, int c, float val) {
+ assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);
+ data[c*w*h + y*w + x] = val;
+ }
+
+ void add_pixel(int x, int y, int c, float val) {
+ assert(x >= 0 && x < w && y >= 0 && y < h && c >= 0 && c < this->c);
+ data[c*w*h + y*w + x] += val;
+ }
+
+ void fill(float val) {
+ std::fill(data.begin(), data.end(), val);
+ }
+};
+
+bool load_image(const char *fname, yolo_image & img);
+void draw_box_width(yolo_image & a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);
+yolo_image letterbox_image(const yolo_image & im, int w, int h);
+bool save_image(const yolo_image & im, const char *name, int quality);
+yolo_image get_label(const std::vector<yolo_image> & alphabet, const std::string & label, int size);
+void draw_label(yolo_image & im, int row, int col, const yolo_image & label, const float * rgb);
--- /dev/null
+#include "ggml/ggml.h"
+#include "yolo-image.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <fstream>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+struct conv2d_layer {
+ struct ggml_tensor * weights;
+ struct ggml_tensor * biases;
+ struct ggml_tensor * scales;
+ struct ggml_tensor * rolling_mean;
+ struct ggml_tensor * rolling_variance;
+ int padding = 1;
+ bool batch_normalize = true;
+ bool activate = true; // true for leaky relu, false for linear
+};
+
+struct yolo_model {
+ int width = 416;
+ int height = 416;
+ std::vector<conv2d_layer> conv2d_layers;
+ struct ggml_context * ctx;
+};
+
+struct yolo_layer {
+ int classes = 80;
+ std::vector<int> mask;
+ std::vector<float> anchors;
+ struct ggml_tensor * predictions;
+
+ yolo_layer(int classes, const std::vector<int> & mask, const std::vector<float> & anchors, struct ggml_tensor * predictions)
+ : classes(classes), mask(mask), anchors(anchors), predictions(predictions)
+ { }
+
+ int entry_index(int location, int entry) const {
+ int w = predictions->ne[0];
+ int h = predictions->ne[1];
+ int n = location / (w*h);
+ int loc = location % (w*h);
+ return n*w*h*(4+classes+1) + entry*w*h + loc;
+ }
+};
+
+struct box {
+ float x, y, w, h;
+};
+
+struct detection {
+ box bbox;
+ std::vector<float> prob;
+ float objectness;
+};
+
+static bool load_model(const std::string & fname, yolo_model & model) {
+ struct gguf_init_params params = {
+ /*.no_alloc =*/ false,
+ /*.ctx =*/ &model.ctx,
+ };
+ gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
+ if (!ctx) {
+ fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+ return false;
+ }
+ model.width = 416;
+ model.height = 416;
+ model.conv2d_layers.resize(13);
+ model.conv2d_layers[7].padding = 0;
+ model.conv2d_layers[9].padding = 0;
+ model.conv2d_layers[9].batch_normalize = false;
+ model.conv2d_layers[9].activate = false;
+ model.conv2d_layers[10].padding = 0;
+ model.conv2d_layers[12].padding = 0;
+ model.conv2d_layers[12].batch_normalize = false;
+ model.conv2d_layers[12].activate = false;
+ for (int i = 0; i < (int)model.conv2d_layers.size(); i++) {
+ char name[256];
+ snprintf(name, sizeof(name), "l%d_weights", i);
+ model.conv2d_layers[i].weights = ggml_get_tensor(model.ctx, name);
+ snprintf(name, sizeof(name), "l%d_biases", i);
+ model.conv2d_layers[i].biases = ggml_get_tensor(model.ctx, name);
+ if (model.conv2d_layers[i].batch_normalize) {
+ snprintf(name, sizeof(name), "l%d_scales", i);
+ model.conv2d_layers[i].scales = ggml_get_tensor(model.ctx, name);
+ snprintf(name, sizeof(name), "l%d_rolling_mean", i);
+ model.conv2d_layers[i].rolling_mean = ggml_get_tensor(model.ctx, name);
+ snprintf(name, sizeof(name), "l%d_rolling_variance", i);
+ model.conv2d_layers[i].rolling_variance = ggml_get_tensor(model.ctx, name);
+ }
+ }
+ return true;
+}
+
+static bool load_labels(const char * filename, std::vector<std::string> & labels)
+{
+ std::ifstream file_in(filename);
+ if (!file_in) {
+ return false;
+ }
+ std::string line;
+ while (std::getline(file_in, line)) {
+ labels.push_back(line);
+ }
+ GGML_ASSERT(labels.size() == 80);
+ return true;
+}
+
+static bool load_alphabet(std::vector<yolo_image> & alphabet)
+{
+ alphabet.resize(8 * 128);
+ for (int j = 0; j < 8; j++) {
+ for (int i = 32; i < 127; i++) {
+ char fname[256];
+ sprintf(fname, "data/labels/%d_%d.png", i, j);
+ if (!load_image(fname, alphabet[j*128 + i])) {
+ fprintf(stderr, "Cannot load '%s'\n", fname);
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+static ggml_tensor * apply_conv2d(ggml_context * ctx, ggml_tensor * input, const conv2d_layer & layer)
+{
+ struct ggml_tensor * result = ggml_conv_2d(ctx, layer.weights, input, 1, 1, layer.padding, layer.padding, 1, 1);
+ if (layer.batch_normalize) {
+ result = ggml_sub(ctx, result, ggml_repeat(ctx, layer.rolling_mean, result));
+ result = ggml_div(ctx, result, ggml_sqrt(ctx, ggml_repeat(ctx, layer.rolling_variance, result)));
+ result = ggml_mul(ctx, result, ggml_repeat(ctx, layer.scales, result));
+ }
+ result = ggml_add(ctx, result, ggml_repeat(ctx, layer.biases, result));
+ if (layer.activate) {
+ result = ggml_leaky(ctx, result);
+ }
+ return result;
+}
+
+static void activate_array(float * x, const int n)
+{
+ // logistic activation
+ for (int i = 0; i < n; i++) {
+ x[i] = 1./(1. + exp(-x[i]));
+ }
+}
+
+static void apply_yolo(yolo_layer & layer)
+{
+ int w = layer.predictions->ne[0];
+ int h = layer.predictions->ne[1];
+ int N = layer.mask.size();
+ float * data = ggml_get_data_f32(layer.predictions);
+ for (int n = 0; n < N; n++) {
+ int index = layer.entry_index(n*w*h, 0);
+ activate_array(data + index, 2*w*h);
+ index = layer.entry_index(n*w*h, 4);
+ activate_array(data + index, (1+layer.classes)*w*h);
+ }
+}
+
+static box get_yolo_box(const yolo_layer & layer, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
+{
+ float * predictions = ggml_get_data_f32(layer.predictions);
+ box b;
+ b.x = (i + predictions[index + 0*stride]) / lw;
+ b.y = (j + predictions[index + 1*stride]) / lh;
+ b.w = exp(predictions[index + 2*stride]) * layer.anchors[2*n] / w;
+ b.h = exp(predictions[index + 3*stride]) * layer.anchors[2*n+1] / h;
+ return b;
+}
+
+static void correct_yolo_box(box & b, int im_w, int im_h, int net_w, int net_h)
+{
+ int new_w = 0;
+ int new_h = 0;
+ if (((float)net_w/im_w) < ((float)net_h/im_h)) {
+ new_w = net_w;
+ new_h = (im_h * net_w)/im_w;
+ } else {
+ new_h = net_h;
+ new_w = (im_w * net_h)/im_h;
+ }
+ b.x = (b.x - (net_w - new_w)/2./net_w) / ((float)new_w/net_w);
+ b.y = (b.y - (net_h - new_h)/2./net_h) / ((float)new_h/net_h);
+ b.w *= (float)net_w/new_w;
+ b.h *= (float)net_h/new_h;
+}
+
+static void get_yolo_detections(const yolo_layer & layer, std::vector<detection> & detections, int im_w, int im_h, int netw, int neth, float thresh)
+{
+ int w = layer.predictions->ne[0];
+ int h = layer.predictions->ne[1];
+ int N = layer.mask.size();
+ float * predictions = ggml_get_data_f32(layer.predictions);
+ std::vector<detection> result;
+ for (int i = 0; i < w*h; i++) {
+ for (int n = 0; n < N; n++) {
+ int obj_index = layer.entry_index(n*w*h + i, 4);
+ float objectness = predictions[obj_index];
+ if (objectness <= thresh) {
+ continue;
+ }
+ detection det;
+ int box_index = layer.entry_index(n*w*h + i, 0);
+ int row = i / w;
+ int col = i % w;
+ det.bbox = get_yolo_box(layer, layer.mask[n], box_index, col, row, w, h, netw, neth, w*h);
+ correct_yolo_box(det.bbox, im_w, im_h, netw, neth);
+ det.objectness = objectness;
+ det.prob.resize(layer.classes);
+ for (int j = 0; j < layer.classes; j++) {
+ int class_index = layer.entry_index(n*w*h + i, 4 + 1 + j);
+ float prob = objectness*predictions[class_index];
+ det.prob[j] = (prob > thresh) ? prob : 0;
+ }
+ detections.push_back(det);
+ }
+ }
+}
+
+static float overlap(float x1, float w1, float x2, float w2)
+{
+ float l1 = x1 - w1/2;
+ float l2 = x2 - w2/2;
+ float left = l1 > l2 ? l1 : l2;
+ float r1 = x1 + w1/2;
+ float r2 = x2 + w2/2;
+ float right = r1 < r2 ? r1 : r2;
+ return right - left;
+}
+
+static float box_intersection(const box & a, const box & b)
+{
+ float w = overlap(a.x, a.w, b.x, b.w);
+ float h = overlap(a.y, a.h, b.y, b.h);
+ if (w < 0 || h < 0) return 0;
+ float area = w*h;
+ return area;
+}
+
+static float box_union(const box & a, const box & b)
+{
+ float i = box_intersection(a, b);
+ float u = a.w*a.h + b.w*b.h - i;
+ return u;
+}
+
+static float box_iou(const box & a, const box & b)
+{
+ return box_intersection(a, b)/box_union(a, b);
+}
+
+static void do_nms_sort(std::vector<detection> & dets, int classes, float thresh)
+{
+ int k = (int)dets.size()-1;
+ for (int i = 0; i <= k; ++i) {
+ if (dets[i].objectness == 0) {
+ std::swap(dets[i], dets[k]);
+ --k;
+ --i;
+ }
+ }
+ int total = k+1;
+ for (int k = 0; k < classes; ++k) {
+ std::sort(dets.begin(), dets.begin()+total, [=](const detection & a, const detection & b) {
+ return a.prob[k] > b.prob[k];
+ });
+ for (int i = 0; i < total; ++i) {
+ if (dets[i].prob[k] == 0) {
+ continue;
+ }
+ box a = dets[i].bbox;
+ for (int j = i+1; j < total; ++j){
+ box b = dets[j].bbox;
+ if (box_iou(a, b) > thresh) {
+ dets[j].prob[k] = 0;
+ }
+ }
+ }
+ }
+}
+
+static float get_color(int c, int x, int max)
+{
+ float colors[6][3] = { {1,0,1}, {0,0,1}, {0,1,1}, {0,1,0}, {1,1,0}, {1,0,0} };
+ float ratio = ((float)x/max)*5;
+ int i = floor(ratio);
+ int j = ceil(ratio);
+ ratio -= i;
+ float r = (1-ratio) * colors[i][c] + ratio*colors[j][c];
+ return r;
+}
+
+static void draw_detections(yolo_image & im, const std::vector<detection> & dets, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
+{
+ int classes = (int)labels.size();
+ for (int i = 0; i < (int)dets.size(); i++) {
+ std::string labelstr;
+ int cl = -1;
+ for (int j = 0; j < (int)dets[i].prob.size(); j++) {
+ if (dets[i].prob[j] > thresh) {
+ if (cl < 0) {
+ labelstr = labels[j];
+ cl = j;
+ } else {
+ labelstr += ", ";
+ labelstr += labels[j];
+ }
+ printf("%s: %.0f%%\n", labels[j].c_str(), dets[i].prob[j]*100);
+ }
+ }
+ if (cl >= 0) {
+ int width = im.h * .006;
+ int offset = cl*123457 % classes;
+ float red = get_color(2,offset,classes);
+ float green = get_color(1,offset,classes);
+ float blue = get_color(0,offset,classes);
+ float rgb[3];
+
+ rgb[0] = red;
+ rgb[1] = green;
+ rgb[2] = blue;
+ box b = dets[i].bbox;
+
+ int left = (b.x-b.w/2.)*im.w;
+ int right = (b.x+b.w/2.)*im.w;
+ int top = (b.y-b.h/2.)*im.h;
+ int bot = (b.y+b.h/2.)*im.h;
+
+ if (left < 0) left = 0;
+ if (right > im.w-1) right = im.w-1;
+ if (top < 0) top = 0;
+ if (bot > im.h-1) bot = im.h-1;
+
+ draw_box_width(im, left, top, right, bot, width, red, green, blue);
+ yolo_image label = get_label(alphabet, labelstr, (im.h*.03));
+ draw_label(im, top + width, left, label, rgb);
+ }
+ }
+}
+
+static void print_shape(int layer, const ggml_tensor * t)
+{
+ printf("Layer %2d output shape: %3d x %3d x %4d x %3d\n", layer, (int)t->ne[0], (int)t->ne[1], (int)t->ne[2], (int)t->ne[3]);
+}
+
+void detect(yolo_image & img, const yolo_model & model, float thresh, const std::vector<std::string> & labels, const std::vector<yolo_image> & alphabet)
+{
+ static size_t buf_size = 20000000 * sizeof(float) * 4;
+ static void * buf = malloc(buf_size);
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_size,
+ /*.mem_buffer =*/ buf,
+ /*.no_alloc =*/ false,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+ struct ggml_cgraph gf = {};
+ std::vector<detection> detections;
+
+ yolo_image sized = letterbox_image(img, model.width, model.height);
+ struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, model.width, model.height, 3, 1);
+ std::memcpy(input->data, sized.data.data(), ggml_nbytes(input));
+ ggml_set_name(input, "input");
+
+ struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
+ print_shape(0, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ print_shape(1, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
+ print_shape(2, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ print_shape(3, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
+ print_shape(4, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ print_shape(5, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
+ print_shape(6, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ print_shape(7, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
+ struct ggml_tensor * layer_8 = result;
+ print_shape(8, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+ print_shape(9, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
+ print_shape(10, result);
+ result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
+ print_shape(11, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
+ print_shape(12, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
+ struct ggml_tensor * layer_13 = result;
+ print_shape(13, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
+ print_shape(14, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
+ struct ggml_tensor * layer_15 = result;
+ print_shape(15, result);
+ result = apply_conv2d(ctx0, layer_13, model.conv2d_layers[10]);
+ print_shape(18, result);
+ result = ggml_upscale(ctx0, result, 2);
+ print_shape(19, result);
+ result = ggml_concat(ctx0, result, layer_8);
+ print_shape(20, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[11]);
+ print_shape(21, result);
+ result = apply_conv2d(ctx0, result, model.conv2d_layers[12]);
+ struct ggml_tensor * layer_22 = result;
+ print_shape(22, result);
+
+ ggml_build_forward_expand(&gf, layer_15);
+ ggml_build_forward_expand(&gf, layer_22);
+ ggml_graph_compute_with_ctx(ctx0, &gf, 1);
+
+ yolo_layer yolo16{ 80, {3, 4, 5}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_15};
+ apply_yolo(yolo16);
+ get_yolo_detections(yolo16, detections, img.w, img.h, model.width, model.height, thresh);
+
+ yolo_layer yolo23{ 80, {0, 1, 2}, {10, 14, 23, 27, 37,58, 81, 82, 135, 169, 344, 319}, layer_22};
+ apply_yolo(yolo23);
+ get_yolo_detections(yolo23, detections, img.w, img.h, model.width, model.height, thresh);
+
+ do_nms_sort(detections, yolo23.classes, .45);
+ draw_detections(img, detections, thresh, labels, alphabet);
+ ggml_free(ctx0);
+}
+
+struct yolo_params {
+ float thresh = 0.5;
+ std::string model = "yolov3-tiny.gguf";
+ std::string fname_inp = "input.jpg";
+ std::string fname_out = "predictions.jpg";
+};
+
+void yolo_print_usage(int argc, char ** argv, const yolo_params & params) {
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
+ fprintf(stderr, "\n");
+ fprintf(stderr, "options:\n");
+ fprintf(stderr, " -h, --help show this help message and exit\n");
+ fprintf(stderr, " -th T, --thresh T detection threshold (default: %.2f)\n", params.thresh);
+ fprintf(stderr, " -m FNAME, --model FNAME\n");
+ fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
+ fprintf(stderr, " -i FNAME, --inp FNAME\n");
+ fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
+ fprintf(stderr, " -o FNAME, --out FNAME\n");
+ fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
+ fprintf(stderr, "\n");
+}
+
+bool yolo_params_parse(int argc, char ** argv, yolo_params & params) {
+ for (int i = 1; i < argc; i++) {
+ std::string arg = argv[i];
+
+ if (arg == "-th" || arg == "--thresh") {
+ params.thresh = std::stof(argv[++i]);
+ } else if (arg == "-m" || arg == "--model") {
+ params.model = argv[++i];
+ } else if (arg == "-i" || arg == "--inp") {
+ params.fname_inp = argv[++i];
+ } else if (arg == "-o" || arg == "--out") {
+ params.fname_out = argv[++i];
+ } else if (arg == "-h" || arg == "--help") {
+ yolo_print_usage(argc, argv, params);
+ exit(0);
+ } else {
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+ yolo_print_usage(argc, argv, params);
+ exit(0);
+ }
+ }
+
+ return true;
+}
+
+int main(int argc, char *argv[])
+{
+ ggml_time_init();
+ yolo_model model;
+
+ yolo_params params;
+ if (!yolo_params_parse(argc, argv, params)) {
+ return 1;
+ }
+ if (!load_model(params.model, model)) {
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+ return 1;
+ }
+ yolo_image img(0,0,0);
+ if (!load_image(params.fname_inp.c_str(), img)) {
+ fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str());
+ return 1;
+ }
+ std::vector<std::string> labels;
+ if (!load_labels("data/coco.names", labels)) {
+ fprintf(stderr, "%s: failed to load labels from 'data/coco.names'\n", __func__);
+ return 1;
+ }
+ std::vector<yolo_image> alphabet;
+ if (!load_alphabet(alphabet)) {
+ fprintf(stderr, "%s: failed to load alphabet\n", __func__);
+ return 1;
+ }
+ const int64_t t_start_ms = ggml_time_ms();
+ detect(img, model, params.thresh, labels, alphabet);
+ const int64_t t_detect_ms = ggml_time_ms() - t_start_ms;
+ if (!save_image(img, params.fname_out.c_str(), 80)) {
+ fprintf(stderr, "%s: failed to save image to '%s'\n", __func__, params.fname_out.c_str());
+ return 1;
+ }
+ printf("Detected objects saved in '%s' (time: %f sec.)\n", params.fname_out.c_str(), t_detect_ms / 1000.0f);
+ ggml_free(model.ctx);
+ return 0;
+}
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
+ GGML_UNARY_OP_LEAKY
};
enum ggml_object_type {
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_leaky(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
int s0, // stride
int p0); // padding
+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
GGML_API struct ggml_tensor * ggml_pool_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k1,
int s0,
int s1,
- int p0,
- int p1);
+ float p0,
+ float p1);
// nearest interpolate
// used in stable-diffusion
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
-transformers==4.29.2
\ No newline at end of file
+transformers==4.29.2
+gguf==0.4.5
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}
+// ggml_leaky
+
+struct ggml_tensor * ggml_leaky(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+}
+
// ggml_gelu
struct ggml_tensor * ggml_gelu(
// ggml_pool_*
-static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
+static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
return (ins + 2 * p - ks) / s + 1;
}
int k1,
int s0,
int s1,
- int p0,
- int p1) {
+ float p0,
+ float p1) {
bool is_node = false;
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}
-
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
}
}
+// ggml_compute_forward_leaky
+
+static void ggml_compute_forward_leaky_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_leaky_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_leaky(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_leaky_f32(params, src0, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_silu_back
static void ggml_compute_forward_silu_back_f32(
ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
}
-// ggml_compute_forward_pool_2d_sk_p0
+// ggml_compute_forward_pool_2d
-static void ggml_compute_forward_pool_2d_sk_p0(
+static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
- const enum ggml_op_pool op,
const struct ggml_tensor * src,
- const int k0,
- const int k1,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
return;
}
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = opts[0];
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
const char * cdata = (const char*)src->data;
const char * const data_end = cdata + ggml_nbytes(src);
float * dplane = (float *)dst->data;
const int ka = k0 * k1;
+ const int offset0 = -p0;
+ const int offset1 = -p1;
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
- const int ix = ox * k0;
- const int iy = oy * k1;
+ const int ix = offset0 + ox * s0;
+ const int iy = offset1 + oy * s1;
for (int ky = 0; ky < k1; ++ky) {
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
+ if (j < 0 || j >= src->ne[0]) continue;
switch (op) {
case GGML_OP_POOL_AVG: *out += srow[j]; break;
case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
}
}
-// ggml_compute_forward_pool_2d
-
-static void ggml_compute_forward_pool_2d(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = opts[0];
- const int k0 = opts[1];
- const int k1 = opts[2];
- const int s0 = opts[3];
- const int s1 = opts[4];
- const int p0 = opts[5];
- const int p1 = opts[6];
- GGML_ASSERT(p0 == 0);
- GGML_ASSERT(p1 == 0); // padding not supported
- GGML_ASSERT(k0 == s0);
- GGML_ASSERT(k1 == s1); // only s = k supported
-
- ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
-}
-
// ggml_compute_forward_upscale
static void ggml_compute_forward_upscale_f32(
{
ggml_compute_forward_silu(params, src0, dst);
} break;
+ case GGML_UNARY_OP_LEAKY:
+ {
+ ggml_compute_forward_leaky(params, src0, dst);
+ } break;
default:
{
GGML_ASSERT(false);
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_LEAKY:
{
n_tasks = 1;
} break;