add google magika inference example (#748)
author slaren <redacted>
Sun, 25 Feb 2024 19:41:35 +0000 (20:41 +0100)
committer GitHub <redacted>
Sun, 25 Feb 2024 19:41:35 +0000 (20:41 +0100)
* add magika inference example

* ggml : fix unaligned accesses in custom ops

* ggml : fix FP32 GELU for values that exceed the FP16 range

* use ggml_pool_1d

* add README

* Update README.md

* pad inputs if the files are too small

* cleanup

ggml-ci

examples/CMakeLists.txt
examples/magika/CMakeLists.txt [new file with mode: 0644]
examples/magika/README.md [new file with mode: 0644]
examples/magika/convert.py [new file with mode: 0644]
examples/magika/main.cpp [new file with mode: 0644]
src/ggml.c

index 5a268dca1be9a3517f992fd389983b36503585d9..d3bf460b014ac5942b87badd9abeb577889045e1 100644 (file)
@@ -24,3 +24,4 @@ add_subdirectory(whisper)
 add_subdirectory(mnist)
 add_subdirectory(sam)
 add_subdirectory(yolo)
+add_subdirectory(magika)
diff --git a/examples/magika/CMakeLists.txt b/examples/magika/CMakeLists.txt
new file mode 100644 (file)
index 0000000..5543237
--- /dev/null
@@ -0,0 +1,21 @@
+#
+# magika
+
+set(TEST_TARGET magika)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+#
+# For GPU offloading
+
+if (GGML_CUBLAS)
+    add_compile_definitions(GGML_USE_CUBLAS)
+endif()
+
+if (GGML_CLBLAST)
+    add_compile_definitions(GGML_USE_CLBLAST)
+endif()
+
+if (GGML_METAL)
+    add_compile_definitions(GGML_USE_METAL)
+endif()
diff --git a/examples/magika/README.md b/examples/magika/README.md
new file mode 100644 (file)
index 0000000..8e1ca27
--- /dev/null
@@ -0,0 +1,23 @@
+# Google Magika inference
+
+Simple example that shows how to use GGML for inference with the [Google Magika](https://github.com/google/magika) file type detection model.
+
+### Usage
+
+- Obtain the Magika model in H5 format
+  - Pinned version: https://github.com/google/magika/blob/4460acb5d3f86807c3b53223229dee2afa50c025/assets_generation/models/standard_v1/model.h5
+- Use `convert.py` to convert the model to gguf format:
+```sh
+  $ python examples/magika/convert.py /path/to/model.h5
+```
+- Invoke the program with the model file and a list of files to identify:
+```sh
+  $ build/bin/magika model.h5.gguf examples/sam/example.jpg examples/magika/convert.py README.md src/ggml.c /bin/gcc write.exe jfk.wav
+  examples/sam/example.jpg      : jpeg (100.00%) pptx (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)
+  examples/magika/convert.py    : python (99.99%) javascript (0.00%) txt (0.00%) asm (0.00%) scala (0.00%)
+  README.md                     : markdown (100.00%) txt (0.00%) yaml (0.00%) ppt (0.00%) shell (0.00%)
+  src/ggml.c                    : c (99.95%) txt (0.04%) asm (0.01%) yaml (0.00%) html (0.00%)
+  /bin/gcc                      : elf (99.98%) odex (0.02%) pptx (0.00%) smali (0.00%) shell (0.00%)
+  write.exe                     : pebin (100.00%) ppt (0.00%) smali (0.00%) shell (0.00%) sevenzip (0.00%)
+  jfk.wav                       : wav (100.00%) ppt (0.00%) shell (0.00%) sevenzip (0.00%) scala (0.00%)
+```
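+- (Optional) Sanity-check the conversion by listing the tensors back out of the generated `.gguf`. This is only a minimal sketch, assuming the `gguf` Python package (the same one `convert.py` uses) exposes a `GGUFReader` class with a `tensors` list:
+```py
+# minimal sketch: list the tensors written by convert.py
+# assumes gguf.GGUFReader is available and that each tensor exposes .name and .shape
+from gguf import GGUFReader
+
+reader = GGUFReader("model.h5.gguf")
+for tensor in reader.tensors:
+    print(tensor.name, tensor.shape)
+```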
diff --git a/examples/magika/convert.py b/examples/magika/convert.py
new file mode 100644 (file)
index 0000000..b901a34
--- /dev/null
@@ -0,0 +1,32 @@
+import sys
+from tensorflow import keras
+import gguf
+
+def convert(model_name):
+    model = keras.models.load_model(model_name, compile=False)
+    gguf_model_name = model_name + ".gguf"
+    gguf_writer = gguf.GGUFWriter(gguf_model_name, "magika")
+
+    for layer in model.layers:
+        # export layers with weights
+        if layer.weights:
+            for weight in layer.weights:
+                print(f"  [{weight.name}] {weight.shape} {weight.dtype}")
+                weight_data = weight.numpy()
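+                # Keras stores Dense kernels as (n_in, n_out); writing the transpose makes
+                # ggml's ne[0] equal n_in, which is what ggml_mul_mat in main.cpp expects
+                # (for the 1-D biases and norm parameters the transpose is a no-op)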
+                gguf_writer.add_tensor(weight.name, weight_data.T)
+
+
+    gguf_writer.write_header_to_file()
+    gguf_writer.write_kv_data_to_file()
+    gguf_writer.write_tensors_to_file()
+    gguf_writer.close()
+    print("Model converted and saved to '{}'".format(gguf_model_name))
+
+
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        model_file = sys.argv[1]
+    else:
+        model_file = "model.h5"
+
+    convert(model_file)
diff --git a/examples/magika/main.cpp b/examples/magika/main.cpp
new file mode 100644 (file)
index 0000000..d55b796
--- /dev/null
@@ -0,0 +1,371 @@
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+#include "ggml/ggml-backend.h"
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+static const char * magika_labels[] = {
+    "ai",                 "apk",                "appleplist",         "asm",                "asp",
+    "batch",              "bmp",                "bzip",               "c",                  "cab",
+    "cat",                "chm",                "coff",               "crx",                "cs",
+    "css",                "csv",                "deb",                "dex",                "dmg",
+    "doc",                "docx",               "elf",                "emf",                "eml",
+    "epub",               "flac",               "gif",                "go",                 "gzip",
+    "hlp",                "html",               "ico",                "ini",                "internetshortcut",
+    "iso",                "jar",                "java",               "javabytecode",       "javascript",
+    "jpeg",               "json",               "latex",              "lisp",               "lnk",
+    "m3u",                "macho",              "makefile",           "markdown",           "mht",
+    "mp3",                "mp4",                "mscompress",         "msi",                "mum",
+    "odex",               "odp",                "ods",                "odt",                "ogg",
+    "outlook",            "pcap",               "pdf",                "pebin",              "pem",
+    "perl",               "php",                "png",                "postscript",         "powershell",
+    "ppt",                "pptx",               "python",             "pythonbytecode",     "rar",
+    "rdf",                "rpm",                "rst",                "rtf",                "ruby",
+    "rust",               "scala",              "sevenzip",           "shell",              "smali",
+    "sql",                "squashfs",           "svg",                "swf",                "symlinktext",
+    "tar",                "tga",                "tiff",               "torrent",            "ttf",
+    "txt",                "unknown",            "vba",                "wav",                "webm",
+    "webp",               "winregistry",        "wmf",                "xar",                "xls",
+    "xlsb",               "xlsx",               "xml",                "xpi",                "xz",
+    "yaml",               "zip",                "zlibstream"
+};
+
+struct magika_hparams {
+    const int block_size = 4096;
+    const int beg_size = 512;
+    const int mid_size = 512;
+    const int end_size = 512;
+    const int min_file_size_for_dl = 16;
+    const int n_label = 113;
+    const float f_norm_eps = 0.001f;
+    const int padding_token = 256;
+};
+
+struct magika_model {
+    ~magika_model() {
+        ggml_backend_buffer_free(buf_w);
+        ggml_backend_free(backend);
+        ggml_free(ctx_w);
+    }
+
+    magika_hparams hparams;
+
+    struct ggml_tensor * dense_w;
+    struct ggml_tensor * dense_b;
+
+    struct ggml_tensor * layer_norm_gamma;
+    struct ggml_tensor * layer_norm_beta;
+
+    struct ggml_tensor * dense_1_w;
+    struct ggml_tensor * dense_1_b;
+
+    struct ggml_tensor * dense_2_w;
+    struct ggml_tensor * dense_2_b;
+
+    struct ggml_tensor * layer_norm_1_gamma;
+    struct ggml_tensor * layer_norm_1_beta;
+
+    struct ggml_tensor * target_label_w;
+    struct ggml_tensor * target_label_b;
+
+    ggml_backend_t backend = ggml_backend_cpu_init();
+    ggml_backend_buffer_t buf_w = nullptr;
+    struct ggml_context * ctx_w = nullptr;
+};
+
+struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name) {
+    struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
+    if (!tensor) {
+        fprintf(stderr, "%s: tensor '%s' not found\n", __func__, name);
+        throw std::runtime_error("ggml_get_tensor() failed");
+    }
+    return tensor;
+}
+
+bool magika_model_load(const std::string & fname, magika_model & model) {
+    auto & ctx = model.ctx_w;
+
+    struct gguf_init_params params = {
+        /*.no_alloc   =*/ true,
+        /*.ctx        =*/ &ctx,
+    };
+
+    struct gguf_context * ctx_gguf = gguf_init_from_file(fname.c_str(), params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+        return false;
+    }
+
+    model.buf_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);
+    if (!model.buf_w) {
+        fprintf(stderr, "%s: ggml_backend_alloc_ctx_tensors() failed\n", __func__);
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    try {
+        model.dense_w = checked_get_tensor(ctx, "dense/kernel:0");
+        model.dense_b = checked_get_tensor(ctx, "dense/bias:0");
+
+        model.layer_norm_gamma = checked_get_tensor(ctx, "layer_normalization/gamma:0");
+        model.layer_norm_beta  = checked_get_tensor(ctx, "layer_normalization/beta:0");
+
+        model.dense_1_w = checked_get_tensor(ctx, "dense_1/kernel:0");
+        model.dense_1_b = checked_get_tensor(ctx, "dense_1/bias:0");
+
+        model.dense_2_w = checked_get_tensor(ctx, "dense_2/kernel:0");
+        model.dense_2_b = checked_get_tensor(ctx, "dense_2/bias:0");
+
+        model.layer_norm_1_gamma = checked_get_tensor(ctx, "layer_normalization_1/gamma:0");
+        model.layer_norm_1_beta  = checked_get_tensor(ctx, "layer_normalization_1/beta:0");
+
+        model.target_label_w = checked_get_tensor(ctx, "target_label/kernel:0");
+        model.target_label_b = checked_get_tensor(ctx, "target_label/bias:0");
+    } catch (const std::exception & e) {
+        fprintf(stderr, "%s: %s\n", __func__, e.what());
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    FILE * f = fopen(fname.c_str(), "rb");
+    if (!f) {
+        fprintf(stderr, "%s: fopen() failed\n", __func__);
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+
+    for (int i = 0; i < n_tensors; i++) {
+        const char * name = gguf_get_tensor_name(ctx_gguf, i);
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
+        size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+
+        //printf("%-30s: [%3ld, %3ld, %3ld, %3ld] %s\n",
+        //    name,
+        //    tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+        //    ggml_type_name(tensor->type));
+
+        std::vector<uint8_t> buf(ggml_nbytes(tensor));
+        if (fseek(f, offs, SEEK_SET) != 0) {
+            fprintf(stderr, "%s: fseek() failed\n", __func__);
+            gguf_free(ctx_gguf);
+            fclose(f);
+            return false;
+        }
+
+        if (fread(buf.data(), 1, buf.size(), f) != buf.size()) {
+            fprintf(stderr, "%s: fread() failed\n", __func__);
+            gguf_free(ctx_gguf);
+            fclose(f);
+            return false;
+        }
+
+        ggml_backend_tensor_set(tensor, buf.data(), 0, buf.size());
+    }
+
+    fclose(f);
+
+    gguf_free(ctx_gguf);
+
+    return true;
+}
+
+struct ggml_cgraph * magika_graph(
+    const magika_model & model,
+    const int n_files) {
+
+    const auto & hparams = model.hparams;
+
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    struct ggml_tensor * input = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 257, 1536, n_files); // one-hot
+    ggml_set_name(input, "input");
+    ggml_set_input(input);
+
+    struct ggml_tensor * cur;
+
+    // dense
+    cur = ggml_mul_mat(ctx, model.dense_w, input);
+    cur = ggml_add(ctx, cur, model.dense_b); // [128, 1536, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // reshape
+    cur = ggml_reshape_3d(ctx, cur, 512, 384, n_files); // [384, 512, n_files]
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));
+
+    // layer normalization
+    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);
+    cur = ggml_mul(ctx, cur, model.layer_norm_gamma); // [384, 512, n_files]
+    cur = ggml_add(ctx, cur, model.layer_norm_beta);  // [384, 512, n_files]
+
+    // dense_1
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur));
+    cur = ggml_mul_mat(ctx, model.dense_1_w, cur);
+    cur = ggml_add(ctx, cur, model.dense_1_b); // [256, 384, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // dense_2
+    cur = ggml_mul_mat(ctx, model.dense_2_w, cur);
+    cur = ggml_add(ctx, cur, model.dense_2_b); // [256, 384, n_files]
+    cur = ggml_gelu(ctx, cur);
+
+    // global_max_pooling1d
+    cur = ggml_cont(ctx, ggml_transpose(ctx, cur)); // [384, 256, n_files]
+    cur = ggml_pool_1d(ctx, cur, GGML_OP_POOL_MAX, 384, 384, 0); // [1, 256, n_files]
+    cur = ggml_reshape_2d(ctx, cur, 256, n_files); // [256, n_files]
+
+    // layer normalization 1
+    cur = ggml_norm(ctx, cur, hparams.f_norm_eps);
+    cur = ggml_mul(ctx, cur, model.layer_norm_1_gamma); // [256, n_files]
+    cur = ggml_add(ctx, cur, model.layer_norm_1_beta);  // [256, n_files]
+
+    // target_label
+    cur = ggml_mul_mat(ctx, model.target_label_w, cur);
+    cur = ggml_add(ctx, cur, model.target_label_b); // [n_label, n_files]
+    cur = ggml_soft_max(ctx, cur); // [n_label, n_files]
+    ggml_set_name(cur, "target_label_probs");
+    ggml_set_output(cur);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
+
+bool magika_eval(
+    struct magika_model & model,
+    const std::vector<std::string> & fnames) {
+
+    const auto & hparams = model.hparams;
+
+    static ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+
+    struct ggml_cgraph * gf = magika_graph(model, fnames.size());
+
+    if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        fprintf(stderr, "%s: ggml_gallocr_alloc_graph() failed\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "input");
+
+    for (size_t i = 0; i < fnames.size(); i++) {
+        FILE * f = fopen(fnames[i].c_str(), "rb");
+        if (!f) {
+            fprintf(stderr, "%s: fopen() failed\n", __func__);
+            return false;
+        }
+        fseek(f, 0, SEEK_END);
+        long fsize = ftell(f);
+
+        // the buffer is padded with the padding_token if the file is smaller than the block size
+        std::vector<int> buf(1536, hparams.padding_token);
+        std::vector<uint8_t> read_buf(std::max(hparams.beg_size, std::max(hparams.mid_size, hparams.end_size)));
+
+        // read beg
+        fseek(f, 0, SEEK_SET);
+        int n_read = fread(read_buf.data(), 1, hparams.beg_size, f);
+        for (int j = 0; j < n_read; j++) {
+            // pad at the end
+            buf[j] = read_buf[j];
+        }
+
+        // read mid
+        long mid_offs = std::max(0L, (fsize - hparams.mid_size) / 2);
+        fseek(f, mid_offs, SEEK_SET);
+        n_read = fread(read_buf.data(), 1, hparams.mid_size, f);
+        for (int j = 0; j < n_read; j++) {
+            // pad at both ends
+            long mid_idx = hparams.beg_size + (hparams.mid_size / 2) - n_read / 2 + j;
+            buf[mid_idx] = read_buf[j];
+        }
+
+        // read end
+        long end_offs = std::max(0L, fsize - hparams.end_size);
+        fseek(f, end_offs, SEEK_SET);
+        n_read = fread(read_buf.data(), 1, hparams.end_size, f);
+        for (int j = 0; j < n_read; j++) {
+            // pad at the beginning
+            int end_idx = hparams.beg_size + hparams.mid_size + hparams.end_size - n_read + j;
+            buf[end_idx] = read_buf[j];
+        }
+
+        fclose(f);
+
+        const size_t inp_bytes = hparams.beg_size + hparams.mid_size + hparams.end_size;
+
+        // convert to one-hot
+        std::vector<float> one_hot(257*inp_bytes);
+        for (size_t j = 0; j < inp_bytes; j++) {
+            one_hot[257*j + buf[j]] = 1.0f;
+        }
+
+        ggml_backend_tensor_set(input, one_hot.data(), 257*inp_bytes*i*sizeof(float), 257*inp_bytes*sizeof(float));
+    }
+
+    if (!ggml_backend_graph_compute(model.backend, gf)) {
+        fprintf(stderr, "%s: ggml_backend_graph_compute() failed\n", __func__);
+        return false;
+    }
+
+    struct ggml_tensor * target_label_probs = ggml_graph_get_tensor(gf, "target_label_probs");
+
+    // print probabilities for the top labels of each file
+    for (size_t i = 0; i < fnames.size(); i++) {
+        std::vector<float> probs(hparams.n_label);
+        ggml_backend_tensor_get(target_label_probs, probs.data(), hparams.n_label*i*sizeof(float), hparams.n_label*sizeof(float));
+
+        // sort the probabilities
+        std::vector<int> idx(hparams.n_label);
+        std::iota(idx.begin(), idx.end(), 0);
+        std::sort(idx.begin(), idx.end(), [&probs](int i1, int i2) { return probs[i1] > probs[i2]; });
+
+        // print the top labels
+        const int top_n = 5;
+        printf("%-30s: ", fnames[i].c_str());
+        for (int j = 0; j < top_n; j++) {
+            printf("%s (%.2f%%) ", magika_labels[idx[j]], probs[idx[j]]*100);
+        }
+        printf("\n");
+    }
+
+    return true;
+}
+
+int main(int argc, const char ** argv) {
+    if (argc < 3) {
+        fprintf(stderr, "usage: %s <model> <file1> [<file2> ...]\n", argv[0]);
+        return 1;
+    }
+
+    const char * model_fname = argv[1];
+    std::vector<std::string> fnames;
+    for (int i = 2; i < argc; i++) {
+        fnames.push_back(argv[i]);
+    }
+
+    magika_model model;
+    if (!magika_model_load(model_fname, model)) {
+        fprintf(stderr, "magika_model_load() failed\n");
+        return 1;
+    }
+
+    magika_eval(model, fnames);
+
+    return 0;
+}
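For reference, the input that `magika_eval()` builds for each file is 512 bytes from the beginning, middle and end of the file, padded with token 256 and one-hot encoded over 257 classes. A hedged Python sketch of the same preprocessing (not part of the commit; assumes numpy, and the helper name `magika_features` is mine):

```py
# Reference sketch of the input layout built in magika_eval():
# beg/mid/end of 512 bytes each, padded with token 256 (PAD),
# then one-hot encoded over 257 classes -> a [1536, 257] feature matrix per file
import numpy as np

BEG, MID, END, PAD = 512, 512, 512, 256

def magika_features(path):
    data = open(path, "rb").read()
    buf = np.full(BEG + MID + END, PAD, dtype=np.int64)

    beg = np.frombuffer(data[:BEG], dtype=np.uint8)
    buf[:len(beg)] = beg                                  # pad at the end

    mid_off = max(0, (len(data) - MID) // 2)
    mid = np.frombuffer(data[mid_off:mid_off + MID], dtype=np.uint8)
    lo = BEG + MID // 2 - len(mid) // 2
    buf[lo:lo + len(mid)] = mid                           # pad at both ends

    end = np.frombuffer(data[-END:], dtype=np.uint8)
    buf[BEG + MID + END - len(end):] = end                # pad at the beginning

    one_hot = np.zeros((BEG + MID + END, 257), dtype=np.float32)
    one_hot[np.arange(len(buf)), buf] = 1.0
    return one_hot
```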
index 23c5e6950fe8e2003089daed9e1f2387ea6d6a14..0fe1f4b52bf23c28f0e67661e9457cc87273a0d6 100644 (file)
@@ -1576,9 +1576,15 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
     }
 }
 #else
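The ±10 thresholds above are conservative: outside that range GELU is already 0 or the identity to well within the precision of the FP16 lookup table, so short-circuiting (and thereby avoiding the FP32→FP16 conversion, which overflows for inputs beyond the FP16 range) does not change results. A quick hedged check using the exact erf-based GELU:

```py
import math

def gelu(x):
    # exact (erf-based) GELU; ggml's tanh approximation behaves the same at these magnitudes
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

print(gelu(-10.0))        # ~ -7.7e-23: below the smallest FP16 subnormal, so the table lookup returned 0 here too
print(10.0 - gelu(10.0))  # ~  7.7e-23: gelu(10) rounds to exactly 10 in FP16, matching the new y = x branch
```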
@@ -5746,11 +5752,13 @@ struct ggml_tensor * ggml_pool_1d(
         is_node = true;
     }
 
-    const int64_t ne[2] = {
+    const int64_t ne[4] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         a->ne[1],
+        a->ne[2],
+        a->ne[3],
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
     ggml_set_op_params(result, params, sizeof(params));
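The 4-D result shape lets ggml_pool_1d operate on batched tensors, which the magika graph needs for its global_max_pooling1d step over a [384, 256, n_files] activation. A hedged numpy sketch of the intended shapes (file count is hypothetical):

```py
# global max pooling over the 384 sequence positions, per file
import numpy as np

x = np.random.rand(3, 256, 384).astype(np.float32)   # 3 hypothetical files
pooled = x.max(axis=-1)                               # max over the sequence dimension
print(pooled.shape)                                   # (3, 256)
```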
@@ -15031,9 +15039,10 @@ static void ggml_compute_forward_map_custom1(
         return;
     }
 
-    struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) dst->op_params;
+    struct ggml_map_custom1_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_map_custom2
@@ -15049,9 +15058,10 @@ static void ggml_compute_forward_map_custom2(
         return;
     }
 
-    struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) dst->op_params;
+    struct ggml_map_custom2_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, b, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_map_custom3
@@ -15068,9 +15078,10 @@ static void ggml_compute_forward_map_custom3(
         return;
     }
 
-    struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) dst->op_params;
+    struct ggml_map_custom3_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
 
-    p->fun(dst, a, b, c, params->ith, params->nth, p->userdata);
+    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
 // ggml_compute_forward_cross_entropy_loss
@@ -17336,29 +17347,32 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_MAP_CUSTOM1:
             {
-                struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom1_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
         case GGML_OP_MAP_CUSTOM2:
             {
-                struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom2_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
         case GGML_OP_MAP_CUSTOM3:
             {
-                struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
-                if (p->n_tasks == GGML_N_TASKS_MAX) {
+                struct ggml_map_custom3_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
                     n_tasks = n_threads;
                 } else {
-                    n_tasks = MIN(p->n_tasks, n_threads);
+                    n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS: