From: slaren Date: Sun, 25 Feb 2024 19:41:35 +0000 (+0100) Subject: add google magika inference example (#748) X-Git-Tag: upstream/0.0.1642~911 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=b458250b736a7473f7ff3560d47c93f1644f3290;p=pkg%2Fggml%2Fsources%2Fggml add google magika inference example (#748) * add magika inference example * ggml : fix unaligned accesses in custom ops * ggml : fix FP32 GELU for values that exceed the FP16 range * use ggml_pool_1d * add README * Update README.md * pad inputs if the files are too small * cleanup ggml-ci --- diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 5a268dca..d3bf460b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -24,3 +24,4 @@ add_subdirectory(whisper) add_subdirectory(mnist) add_subdirectory(sam) add_subdirectory(yolo) +add_subdirectory(magika) diff --git a/examples/magika/CMakeLists.txt b/examples/magika/CMakeLists.txt new file mode 100644 index 00000000..5543237b --- /dev/null +++ b/examples/magika/CMakeLists.txt @@ -0,0 +1,21 @@ +# +# magika + +set(TEST_TARGET magika) +add_executable(${TEST_TARGET} main.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) + +# +# For GPU offloading + +if (GGML_CUBLAS) + add_compile_definitions(GGML_USE_CUBLAS) +endif() + +if (GGML_CLBLAST) + add_compile_definitions(GGML_USE_CLBLAST) +endif() + +if (GGML_METAL) + add_compile_definitions(GGML_USE_METAL) +endif() diff --git a/examples/magika/README.md b/examples/magika/README.md new file mode 100644 index 00000000..8e1ca27d --- /dev/null +++ b/examples/magika/README.md @@ -0,0 +1,23 @@ +# Google Magika inference + +Simple example that shows how to use GGML for inference with the [Google Magika](https://github.com/google/magika) file type detection model. 
+
+### Usage
+
+- Obtain the Magika model in H5 format
+  - Pinned version: https://github.com/google/magika/blob/4460acb5d3f86807c3b53223229dee2afa50c025/assets_generation/models/standard_v1/model.h5
+- Use `convert.py` to convert the model to GGUF format:
+```sh
+ $ python examples/magika/convert.py /path/to/model.h5
+```
+- Invoke the program with the model file and a list of files to identify:
+```sh
+ $ build/bin/magika model.h5.gguf examples/sam/example.jpg examples/magika/convert.py README.md src/ggml.c /bin/gcc write.exe jfk.wav
+ examples/sam/example.jpg      : jpeg (100.00%) pptx (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)
+ examples/magika/convert.py    : python (99.99%) javascript (0.00%) txt (0.00%) asm (0.00%) scala (0.00%)
+ README.md                     : markdown (100.00%) txt (0.00%) yaml (0.00%) ppt (0.00%) shell (0.00%)
+ src/ggml.c                    : c (99.95%) txt (0.04%) asm (0.01%) yaml (0.00%) html (0.00%)
+ /bin/gcc                      : elf (99.98%) odex (0.02%) pptx (0.00%) smali (0.00%) shell (0.00%)
+ write.exe                     : pebin (100.00%) ppt (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)
+ jfk.wav                       : wav (100.00%) ppt (0.00%) shell (0.00%) sevenzip (0.00%) scala (0.00%)
+```
diff --git a/examples/magika/convert.py b/examples/magika/convert.py
new file mode 100644
index 00000000..b901a34f
--- /dev/null
+++ b/examples/magika/convert.py
@@ -0,0 +1,32 @@
+import sys
+from tensorflow import keras
+import gguf
+
+def convert(model_name):
+    model = keras.models.load_model(model_name, compile=False)
+    gguf_model_name = model_name + ".gguf"
+    gguf_writer = gguf.GGUFWriter(gguf_model_name, "magika")
+
+    for layer in model.layers:
+        # export layers with weights
+        if layer.weights:
+            for weight in layer.weights:
+                print(f"  [{weight.name}] {weight.shape} {weight.dtype}")
+                weight_data = weight.numpy()
+                gguf_writer.add_tensor(weight.name, weight_data.T)
+
+
+    gguf_writer.write_header_to_file()
+    gguf_writer.write_kv_data_to_file()
+    gguf_writer.write_tensors_to_file()
+    gguf_writer.close()
+    print("Model converted and saved to '{}'".format(gguf_model_name))
+
+
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        model_file = sys.argv[1]
+    else:
+        model_file = "model.h5"
+
+    convert(model_file)
diff --git a/examples/magika/main.cpp b/examples/magika/main.cpp
new file mode 100644
index 00000000..d55b7960
--- /dev/null
+++ b/examples/magika/main.cpp
@@ -0,0 +1,371 @@
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+#include <algorithm>
+#include <cstdio>
+#include <numeric>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+static const char * magika_labels[] = {
+    "ai", "apk", "appleplist", "asm", "asp",
+    "batch", "bmp", "bzip", "c", "cab",
+    "cat", "chm", "coff", "crx", "cs",
+    "css", "csv", "deb", "dex", "dmg",
+    "doc", "docx", "elf", "emf", "eml",
+    "epub", "flac", "gif", "go", "gzip",
+    "hlp", "html", "ico", "ini", "internetshortcut",
+    "iso", "jar", "java", "javabytecode", "javascript",
+    "jpeg", "json", "latex", "lisp", "lnk",
+    "m3u", "macho", "makefile", "markdown", "mht",
+    "mp3", "mp4", "mscompress", "msi", "mum",
+    "odex", "odp", "ods", "odt", "ogg",
+    "outlook", "pcap", "pdf", "pebin", "pem",
+    "perl", "php", "png", "postscript", "powershell",
+    "ppt", "pptx", "python", "pythonbytecode", "rar",
+    "rdf", "rpm", "rst", "rtf", "ruby",
+    "rust", "scala", "sevenzip", "shell", "smali",
+    "sql", "squashfs", "svg", "swf", "symlinktext",
+    "tar", "tga", "tiff", "torrent", "ttf",
+    "txt", "unknown", "vba", "wav", "webm",
+    "webp", "winregistry", "wmf", "xar", "xls",
+    "xlsb", "xlsx", "xml", "xpi", "xz",
+    "yaml", "zip", "zlibstream"
+};
+
+struct magika_hparams {
+    const int block_size = 4096;
+    const int beg_size = 512;
+    const int mid_size = 512;
+    const int end_size = 512;
+    const int min_file_size_for_dl = 16;
+    const int n_label = 113;
+    const float f_norm_eps = 0.001f;
+    const int padding_token = 256;
+};
+
+struct magika_model {
+    ~magika_model() {
+        ggml_backend_buffer_free(buf_w);
+        ggml_backend_free(backend);
+        ggml_free(ctx_w);
+    }
+
+    magika_hparams hparams;
+
+    struct ggml_tensor * dense_w;
+    struct ggml_tensor * dense_b;
+
+    struct ggml_tensor * layer_norm_gamma;
+    struct ggml_tensor * layer_norm_beta;
+
+    struct ggml_tensor * dense_1_w;
+    struct ggml_tensor * dense_1_b;
+
+    struct ggml_tensor * dense_2_w;
+    struct ggml_tensor * dense_2_b;
+
+    struct ggml_tensor * layer_norm_1_gamma;
+    struct ggml_tensor * layer_norm_1_beta;
+
+    struct ggml_tensor * target_label_w;
+    struct ggml_tensor * target_label_b;
+
+    ggml_backend_t backend = ggml_backend_cpu_init();
+    ggml_backend_buffer_t buf_w = nullptr;
+    struct ggml_context * ctx_w = nullptr;
+};
+
+struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name) {
+    struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
+    if (!tensor) {
+        fprintf(stderr, "%s: tensor '%s' not found\n", __func__, name);
+        throw std::runtime_error("ggml_get_tensor() failed");
+    }
+    return tensor;
+}
+
+bool magika_model_load(const std::string & fname, magika_model & model) {
+    auto & ctx = model.ctx_w;
+
+    struct gguf_init_params params = {
+        /*.no_alloc =*/ true,
+        /*.ctx      =*/ &ctx,
+    };
+
+    struct gguf_context * ctx_gguf = gguf_init_from_file(fname.c_str(), params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+        return false;
+    }
+
+    model.buf_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);
+    if (!model.buf_w) {
+        fprintf(stderr, "%s: ggml_backend_alloc_ctx_tensors() failed\n", __func__);
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    try {
+        model.dense_w = checked_get_tensor(ctx, "dense/kernel:0");
+        model.dense_b = checked_get_tensor(ctx, "dense/bias:0");
+
+        model.layer_norm_gamma = checked_get_tensor(ctx, "layer_normalization/gamma:0");
+        model.layer_norm_beta = checked_get_tensor(ctx, "layer_normalization/beta:0");
+
+        model.dense_1_w = checked_get_tensor(ctx, "dense_1/kernel:0");
+        model.dense_1_b = checked_get_tensor(ctx, "dense_1/bias:0");
+
+        model.dense_2_w = checked_get_tensor(ctx, "dense_2/kernel:0");
+        model.dense_2_b = checked_get_tensor(ctx, "dense_2/bias:0");
+
+        model.layer_norm_1_gamma = checked_get_tensor(ctx, "layer_normalization_1/gamma:0");
+        model.layer_norm_1_beta = checked_get_tensor(ctx, "layer_normalization_1/beta:0");
+
+        model.target_label_w = checked_get_tensor(ctx, "target_label/kernel:0");
+        model.target_label_b = checked_get_tensor(ctx, "target_label/bias:0");
+    } catch (const std::exception & e) {
+        fprintf(stderr, "%s: %s\n", __func__, e.what());
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    FILE * f = fopen(fname.c_str(), "rb");
+    if (!f) {
+        fprintf(stderr, "%s: fopen() failed\n", __func__);
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+
+    for (int i = 0; i < n_tensors; i++) {
+        const char * name = gguf_get_tensor_name(ctx_gguf, i);
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
+        size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+
+        //printf("%-30s: [%3ld, %3ld, %3ld, %3ld] %s\n",
+        //    name,
+        //    tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+        //    ggml_type_name(tensor->type));
+
+        std::vector<uint8_t> buf(ggml_nbytes(tensor));
+        if (fseek(f, offs, SEEK_SET) != 0) {
+            fprintf(stderr, "%s: fseek() failed\n", __func__);
+            gguf_free(ctx_gguf);
+            fclose(f);
+            return false;
+        }
+
+        if (fread(buf.data(), 1, buf.size(), f) != buf.size()) {
+            fprintf(stderr, "%s: fread() failed\n", __func__);
+            gguf_free(ctx_gguf);
+            fclose(f);
+            return false;
+        }
+
+        ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size());
+    }
+
+    fclose(f);
+
+    gguf_free(ctx_gguf);
+
+    return true;
+}
+
+struct ggml_cgraph * magika_graph(
+        const magika_model & model,
+        const int n_files) {
+
+    const auto & hparams = model.hparams;
+
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    struct ggml_tensor * input = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 257, 1536, n_files); // one-hot
+    ggml_set_name(input, "input");
+    ggml_set_input(input);
+
+    struct ggml_tensor * cur;
+
+    // dense
+    cur = ggml_mul_mat(ctx, model.dense_w, input);
+    cur = ggml_add(ctx, cur, model.dense_b); // [128, 1536, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // reshape
+    cur = ggml_reshape_3d(ctx, cur, 512, 384, n_files); // [384, 512, n_files]
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));
+
+    // layer normalization
+    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);
+    cur = ggml_mul(ctx, cur, model.layer_norm_gamma); // [384, 512, n_files]
+    cur = ggml_add(ctx, cur, model.layer_norm_beta); // [384, 512, n_files]
+
+    // dense_1
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));
+    cur = ggml_mul_mat(ctx, model.dense_1_w, cur);
+    cur = ggml_add(ctx, cur, model.dense_1_b); // [256, 384, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // dense_2
+    cur = ggml_mul_mat(ctx, model.dense_2_w, cur);
+    cur = ggml_add(ctx, cur, model.dense_2_b); // [256, 384, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // global_max_pooling1d
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur)); // [384, 256, n_files]
+    cur = ggml_pool_1d(ctx, cur, GGML_OP_POOL_MAX, 384, 384, 0); // [1, 256, n_files]
+    cur = ggml_reshape_2d(ctx, cur, 256, n_files); // [256, n_files]
+
+    // layer normalization 1
+    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);
+    cur = ggml_mul(ctx, cur, model.layer_norm_1_gamma); // [256, n_files]
+    cur = ggml_add(ctx, cur, model.layer_norm_1_beta); // [256, n_files]
+
+    // target_label
+    cur = ggml_mul_mat(ctx, model.target_label_w, cur);
+    cur = ggml_add(ctx, cur, model.target_label_b); // [n_label, n_files]
+    cur = ggml_soft_max(ctx, cur); // [n_label, n_files]
+    ggml_set_name(cur, "target_label_probs");
+    ggml_set_output(cur);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
+
+bool magika_eval(
+        struct magika_model & model,
+        const std::vector<std::string> & fnames) {
+
+    const auto & hparams = model.hparams;
+
+    static ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+
+    struct ggml_cgraph * gf = magika_graph(model, fnames.size());
+
+    if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        fprintf(stderr, "%s: ggml_gallocr_alloc_graph() failed\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
+
+    for (size_t i = 0; i < fnames.size(); i++) {
+        FILE * f = fopen(fnames[i].c_str(), "rb");
+        if (!f) {
+            fprintf(stderr, "%s: fopen() failed\n", __func__);
failed\n", __func__); + return false; + } + fseek(f, 0, SEEK_END); + long fsize = ftell(f); + + // the buffer is padded with the padding_token if the file is smaller than the block size + std::vector buf(1536, hparams.padding_token); + std::vector read_buf(std::max(hparams.beg_size, std::max(hparams.mid_size, hparams.end_size))); + + // read beg + fseek(f, 0, SEEK_SET); + int n_read = fread(read_buf.data(), 1, hparams.beg_size, f); + for (int j = 0; j < n_read; j++) { + // pad at the end + buf[j] = read_buf[j]; + } + + // read mid + long mid_offs = std::max(0L, (fsize - hparams.mid_size) / 2); + fseek(f, mid_offs, SEEK_SET); + n_read = fread(read_buf.data(), 1, hparams.mid_size, f); + for (int j = 0; j < n_read; j++) { + // pad at both ends + long mid_idx = hparams.beg_size + (hparams.mid_size / 2) - n_read / 2 + j; + buf[mid_idx] = read_buf[j]; + } + + // read end + long end_offs = std::max(0L, fsize - hparams.end_size); + fseek(f, end_offs, SEEK_SET); + n_read = fread(read_buf.data(), 1, hparams.end_size, f); + for (int j = 0; j < n_read; j++) { + // pad at the beginning + int end_idx = hparams.beg_size + hparams.mid_size + hparams.end_size - n_read + j; + buf[end_idx] = read_buf[j]; + } + + fclose(f); + + const size_t inp_bytes = hparams.beg_size + hparams.mid_size + hparams.end_size; + + // convert to one-hot + std::vector one_hot(257*inp_bytes); + for (size_t j = 0; j < inp_bytes; j++) { + one_hot[257*j + buf[j]] = 1.0f; + } + + ggml_backend_tensor_set(input, one_hot.data(), 257*inp_bytes*i*sizeof(float), 257*inp_bytes*sizeof(float)); + } + + if (!ggml_backend_graph_compute(model.backend, gf)) { + fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__); + return false; + } + + struct ggml_tensor * target_label_probs = ggml_graph_get_tensor(gf, "target_label_probs"); + + // print probabilities for the top labels of each file + for (size_t i = 0; i < fnames.size(); i++) { + std::vector probs(hparams.n_label); + ggml_backend_tensor_get(target_label_probs, probs.data(), hparams.n_label*i*sizeof(float), hparams.n_label*sizeof(float)); + + // sort the probabilities + std::vector idx(hparams.n_label); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), [&probs](int i1, int i2) { return probs[i1] > probs[i2]; }); + + // print the top labels + const int top_n = 5; + printf("%-30s: ", fnames[i].c_str()); + for (int j = 0; j < top_n; j++) { + printf("%s (%.2f%%) ", magika_labels[idx[j]], probs[idx[j]]*100); + } + printf("\n"); + } + + return true; +} + +int main(int argc, const char ** argv) { + if (argc < 3) { + fprintf(stderr, "usage: %s [ ...]\n", argv[0]); + return 1; + } + + const char * model_fname = argv[1]; + std::vector fnames; + for (int i = 2; i < argc; i++) { + fnames.push_back(argv[i]); + } + + magika_model model; + if (!magika_model_load(model_fname, model)) { + fprintf(stderr, "magika_model_load() failed\n"); + return 1; + } + + magika_eval(model, fnames); + + return 0; +} diff --git a/src/ggml.c b/src/ggml.c index 23c5e695..0fe1f4b5 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1576,9 +1576,15 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { uint16_t t; for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i]; + } else { + ggml_fp16_t fp16 
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
     }
 }
 #else
@@ -5746,11 +5752,13 @@ struct ggml_tensor * ggml_pool_1d(
         is_node = true;
     }
 
-    const int64_t ne[2] = {
+    const int64_t ne[4] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         a->ne[1],
+        a->ne[2],
+        a->ne[3],
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
     ggml_set_op_params(result, params, sizeof(params));
@@ -15031,9 +15039,10 @@ static void ggml_compute_forward_map_custom1(
         return;
     }
 
-    struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) dst->op_params;
+    struct ggml_map_custom1_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_map_custom2
@@ -15049,9 +15058,10 @@ static void ggml_compute_forward_map_custom2(
         return;
     }
 
-    struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) dst->op_params;
+    struct ggml_map_custom2_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, b, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_map_custom3
@@ -15068,9 +15078,10 @@ static void ggml_compute_forward_map_custom3(
        return;
     }
 
-    struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) dst->op_params;
+    struct ggml_map_custom3_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, b, c, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_cross_entropy_loss
@@ -17336,29 +17347,32 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_MAP_CUSTOM1:
             {
-                struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom1_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
         case GGML_OP_MAP_CUSTOM2:
            {
-                struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom2_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
            } break;
        case GGML_OP_MAP_CUSTOM3:
           {
-                struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom3_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
           } break;
       case GGML_OP_CROSS_ENTROPY_LOSS: