]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Initial release
authorGeorgi Gerganov <redacted>
Fri, 10 Mar 2023 18:40:58 +0000 (20:40 +0200)
committerGeorgi Gerganov <redacted>
Fri, 10 Mar 2023 18:56:40 +0000 (20:56 +0200)
.gitignore [new file with mode: 0644]
Makefile [new file with mode: 0644]
convert-pth-to-ggml.py [new file with mode: 0644]
ggml.c [new file with mode: 0644]
ggml.h [new file with mode: 0644]
main.cpp [new file with mode: 0644]
quantize.cpp [new file with mode: 0644]
utils.cpp [new file with mode: 0644]
utils.h [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..4879be4
--- /dev/null
@@ -0,0 +1,21 @@
+*.o
+*.a
+.cache/
+.vs/
+.vscode/
+.DS_Store
+
+build/
+build-em/
+build-debug/
+build-release/
+build-static/
+build-no-accel/
+build-sanitize-addr/
+build-sanitize-thread/
+
+/main
+/quantize
+
+arm_neon.h
+compile_commands.json
diff --git a/Makefile b/Makefile
new file mode 100644 (file)
index 0000000..862d02b
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,203 @@
+ifndef UNAME_S
+UNAME_S := $(shell uname -s)
+endif
+
+ifndef UNAME_P
+UNAME_P := $(shell uname -p)
+endif
+
+ifndef UNAME_M
+UNAME_M := $(shell uname -m)
+endif
+
+CCV := $(shell $(CC) --version | head -n 1)
+CXXV := $(shell $(CXX) --version | head -n 1)
+
+# Mac OS + Arm can report x86_64
+# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
+ifeq ($(UNAME_S),Darwin)
+       ifneq ($(UNAME_P),arm)
+               SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
+               ifeq ($(SYSCTL_M),1)
+                       # UNAME_P := arm
+                       # UNAME_M := arm64
+                       warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
+               endif
+       endif
+endif
+
+#
+# Compile flags
+#
+
+CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
+CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+LDFLAGS  =
+
+# OS specific
+# TODO: support Windows
+ifeq ($(UNAME_S),Linux)
+       CFLAGS   += -pthread
+       CXXFLAGS += -pthread
+endif
+ifeq ($(UNAME_S),Darwin)
+       CFLAGS   += -pthread
+       CXXFLAGS += -pthread
+endif
+ifeq ($(UNAME_S),FreeBSD)
+       CFLAGS   += -pthread
+       CXXFLAGS += -pthread
+endif
+ifeq ($(UNAME_S),Haiku)
+       CFLAGS   += -pthread
+       CXXFLAGS += -pthread
+endif
+
+# Architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+#       feel free to update the Makefile for your architecture and send a pull request or issue
+ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
+       ifeq ($(UNAME_S),Darwin)
+               CFLAGS += -mf16c
+               AVX1_M := $(shell sysctl machdep.cpu.features)
+               ifneq (,$(findstring FMA,$(AVX1_M)))
+                       CFLAGS += -mfma
+               endif
+               ifneq (,$(findstring AVX1.0,$(AVX1_M)))
+                       CFLAGS += -mavx
+               endif
+               AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
+               ifneq (,$(findstring AVX2,$(AVX2_M)))
+                       CFLAGS += -mavx2
+               endif
+       else ifeq ($(UNAME_S),Linux)
+               AVX1_M := $(shell grep "avx " /proc/cpuinfo)
+               ifneq (,$(findstring avx,$(AVX1_M)))
+                       CFLAGS += -mavx
+               endif
+               AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
+               ifneq (,$(findstring avx2,$(AVX2_M)))
+                       CFLAGS += -mavx2
+               endif
+               FMA_M := $(shell grep "fma " /proc/cpuinfo)
+               ifneq (,$(findstring fma,$(FMA_M)))
+                       CFLAGS += -mfma
+               endif
+               F16C_M := $(shell grep "f16c " /proc/cpuinfo)
+               ifneq (,$(findstring f16c,$(F16C_M)))
+                       CFLAGS += -mf16c
+               endif
+               SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
+               ifneq (,$(findstring sse3,$(SSE3_M)))
+                       CFLAGS += -msse3
+               endif
+       else ifeq ($(UNAME_S),Haiku)
+               AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
+               ifneq (,$(findstring avx,$(AVX1_M)))
+                       CFLAGS += -mavx
+               endif
+               AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
+               ifneq (,$(findstring avx2,$(AVX2_M)))
+                       CFLAGS += -mavx2
+               endif
+               FMA_M := $(shell sysinfo -cpu | grep "FMA ")
+               ifneq (,$(findstring fma,$(FMA_M)))
+                       CFLAGS += -mfma
+               endif
+               F16C_M := $(shell sysinfo -cpu | grep "F16C ")
+               ifneq (,$(findstring f16c,$(F16C_M)))
+                       CFLAGS += -mf16c
+               endif
+       else
+               CFLAGS += -mfma -mf16c -mavx -mavx2
+       endif
+endif
+ifeq ($(UNAME_M),amd64)
+       CFLAGS += -mavx -mavx2 -mfma -mf16c
+endif
+ifneq ($(filter ppc64%,$(UNAME_M)),)
+       POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
+       ifneq (,$(findstring POWER9,$(POWER9_M)))
+               CFLAGS += -mpower9-vector
+       endif
+       # Require c++23's std::byteswap for big-endian support.
+       ifeq ($(UNAME_M),ppc64)
+               CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
+       endif
+endif
+ifndef WHISPER_NO_ACCELERATE
+       # Mac M1 - include Accelerate framework
+       ifeq ($(UNAME_S),Darwin)
+               CFLAGS  += -DGGML_USE_ACCELERATE
+               LDFLAGS += -framework Accelerate
+       endif
+endif
+ifdef WHISPER_OPENBLAS
+       CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
+       LDFLAGS += -lopenblas
+endif
+ifdef WHISPER_GPROF
+       CFLAGS   += -pg
+       CXXFLAGS += -pg
+endif
+ifneq ($(filter aarch64%,$(UNAME_M)),)
+       CFLAGS += -mcpu=native
+       CXXFLAGS += -mcpu=native
+endif
+ifneq ($(filter armv6%,$(UNAME_M)),)
+       # Raspberry Pi 1, 2, 3
+       CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+endif
+ifneq ($(filter armv7%,$(UNAME_M)),)
+       # Raspberry Pi 4
+       CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+endif
+ifneq ($(filter armv8%,$(UNAME_M)),)
+       # Raspberry Pi 4
+       CFLAGS += -mfp16-format=ieee -mno-unaligned-access
+endif
+
+#
+# Print build information
+#
+
+$(info I llama.cpp build info: )
+$(info I UNAME_S:  $(UNAME_S))
+$(info I UNAME_P:  $(UNAME_P))
+$(info I UNAME_M:  $(UNAME_M))
+$(info I CFLAGS:   $(CFLAGS))
+$(info I CXXFLAGS: $(CXXFLAGS))
+$(info I LDFLAGS:  $(LDFLAGS))
+$(info I CC:       $(CCV))
+$(info I CXX:      $(CXXV))
+$(info )
+
+default: main quantize
+
+#
+# Build library
+#
+
+ggml.o: ggml.c ggml.h
+       $(CC)  $(CFLAGS)   -c ggml.c -o ggml.o
+
+utils.o: utils.cpp utils.h
+       $(CXX) $(CXXFLAGS) -c utils.cpp -o utils.o
+
+clean:
+       rm -f *.o main quantize
+
+main: main.cpp ggml.o utils.o
+       $(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o -o main $(LDFLAGS)
+       ./main -h
+
+quantize: quantize.cpp ggml.o utils.o
+       $(CXX) $(CXXFLAGS) quantize.cpp ggml.o utils.o -o quantize $(LDFLAGS)
+
+#
+# Tests
+#
+
+.PHONY: tests
+tests:
+       bash ./tests/run-tests.sh
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
new file mode 100644 (file)
index 0000000..d0a187c
--- /dev/null
@@ -0,0 +1,136 @@
+# Convert a LLaMA model checkpoint to a ggml compatible file
+#
+# Load the model using Torch
+# Iterate over all variables and write them to a binary file.
+#
+# For each variable, write the following:
+#   - Number of dimensions (int)
+#   - Name length (int)
+#   - Dimensions (int[n_dims])
+#   - Name (char[name_length])
+#   - Data (float[n_dims])
+#
+# By default, the bigger matrices are converted to 16-bit floats.
+# This can be disabled by adding the "use-f32" CLI argument.
+#
+# At the start of the ggml file we write the model parameters
+# and vocabulary.
+#
+
+import sys
+import json
+import struct
+import numpy as np
+import torch
+
+from sentencepiece import SentencePieceProcessor
+
+if len(sys.argv) < 3:
+    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
+    print("  ftype == 0 -> float32")
+    print("  ftype == 1 -> float16")
+    sys.exit(1)
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+fname_hparams   = sys.argv[1] + "/params.json"
+fname_model     = sys.argv[1] + "/consolidated.00.pth"
+fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if len(sys.argv) > 2:
+    ftype = int(sys.argv[2])
+    if ftype < 0 or ftype > 1:
+        print("Invalid ftype: " + str(ftype))
+        sys.exit(1)
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+
+with open(fname_hparams, "r") as f:
+    hparams = json.load(f)
+
+tokenizer = SentencePieceProcessor(fname_tokenizer)
+
+hparams.update({"vocab_size": tokenizer.vocab_size()})
+
+print(hparams)
+
+model = torch.load(fname_model, map_location="cpu")
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("i", hparams["dim"]))
+fout.write(struct.pack("i", hparams["multiple_of"]))
+fout.write(struct.pack("i", hparams["n_heads"]))
+fout.write(struct.pack("i", hparams["n_layers"]))
+fout.write(struct.pack("i", 64)) # rot
+fout.write(struct.pack("i", ftype))
+
+# Is this correct??
+for i in range(32000):
+    # TODO: this is probably wrong - not sure how this tokenizer works
+    text = tokenizer.decode([29889, i]).encode('utf-8')
+    # remove the first byte (it's always '.')
+    text = text[1:]
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for k, v in model.items():
+    name = k
+    shape = v.shape
+
+    # skip layers.X.attention.inner_attention.rope.freqs
+    if name[-5:] == "freqs":
+        continue
+
+    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
+
+    #data = tf.train.load_variable(dir_model, name).squeeze()
+    data = v.numpy().squeeze()
+    n_dims = len(data.shape);
+
+    # for efficiency - transpose some matrices
+    # "model/h.*/attn/c_attn/w"
+    # "model/h.*/attn/c_proj/w"
+    # "model/h.*/mlp/c_fc/w"
+    # "model/h.*/mlp/c_proj/w"
+    #if name[-14:] == "/attn/c_attn/w" or \
+    #   name[-14:] == "/attn/c_proj/w" or \
+    #   name[-11:] == "/mlp/c_fc/w" or \
+    #   name[-13:] == "/mlp/c_proj/w":
+    #    print("  Transposing")
+    #    data = data.transpose()
+
+    dshape = data.shape
+
+    # default type is fp16
+    ftype_cur = 1
+    if ftype == 0 or n_dims == 1:
+        print("  Converting to float32")
+        data = data.astype(np.float32)
+        ftype_cur = 0
+
+    # header
+    str = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
+    fout.write(str);
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/ggml.c b/ggml.c
new file mode 100644 (file)
index 0000000..ee3b0af
--- /dev/null
+++ b/ggml.c
@@ -0,0 +1,10324 @@
+#include "ggml.h"
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <float.h>
+
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+
+#if defined _MSC_VER || defined(__MINGW32__)
+
+#if !defined(__MINGW32__)
+#include <Windows.h>
+#else
+// ref: https://github.com/ggerganov/whisper.cpp/issues/168
+#include <windows.h>
+#include <errno.h>
+#endif
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+
+static void atomic_store(atomic_int* ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int* ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+    return atomic_fetch_add(ptr, -(dec));
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+    if (handle == NULL)
+    {
+        return EAGAIN;
+    }
+
+    *out = handle;
+    return 0;
+}
+
+static int pthread_join(pthread_t thread, void* unused) {
+    return (int) WaitForSingleObject(thread, INFINITE);
+}
+
+static int sched_yield (void) {
+    Sleep (0);
+    return 0;
+}
+#else
+#include <pthread.h>
+#include <stdatomic.h>
+
+typedef void* thread_ret_t;
+#endif
+
+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
+
+/*#define GGML_PERF*/
+#define GGML_DEBUG 0
+#define GGML_GELU_FP16
+#define GGML_SILU_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  2
+
+#ifdef GGML_USE_ACCELERATE
+// uncomment to use vDSP for soft max computation
+// note: not sure if it is actually faster
+//#define GGML_SOFT_MAX_ACCELERATE
+#endif
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
+
+#define UNUSED(x) (void)(x)
+#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+
+#define GGML_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+#ifdef GGML_USE_ACCELERATE
+#include <Accelerate/Accelerate.h>
+#elif GGML_USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#include <immintrin.h>
+#endif
+#endif
+
+#ifdef __F16C__
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+       union {
+               float as_value;
+               uint32_t as_bits;
+       } fp32;
+       fp32.as_value = f;
+       return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
+
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_f16[1 << 16];
+
+// precomputed silu table for f16 (128 KB)
+static ggml_fp16_t table_silu_f16[1 << 16];
+
+// precomputed exp table for f16 (128 KB)
+static ggml_fp16_t table_exp_f16[1 << 16];
+
+// precomputed f32 table for f16 (256 KB)
+static float table_f32_f16[1 << 16];
+
+// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
+// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
+
+// note: do not use these inside ggml.c
+// these are meant to be used via the ggml.h API
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+    return GGML_FP16_TO_FP32(x);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+    return GGML_FP32_TO_FP16(x);
+}
+
+//
+// timing
+//
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+static int64_t timer_freq;
+void ggml_time_init(void) {
+    LARGE_INTEGER frequency;
+    QueryPerformanceFrequency(&frequency);
+    timer_freq = frequency.QuadPart;
+}
+int64_t ggml_time_ms(void) {
+    LARGE_INTEGER t;
+    QueryPerformanceCounter(&t);
+    return (t.QuadPart * 1000) / timer_freq;
+}
+int64_t ggml_time_us(void) {
+    LARGE_INTEGER t;
+    QueryPerformanceCounter(&t);
+    return (t.QuadPart * 1000000) / timer_freq;
+}
+#else
+void ggml_time_init(void) {}
+int64_t ggml_time_ms(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+#endif
+
+int64_t ggml_cycles(void) {
+    return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+    return CLOCKS_PER_SEC/1000;
+}
+
+#ifdef GGML_PERF
+#define ggml_perf_time_ms()       ggml_time_ms()
+#define ggml_perf_time_us()       ggml_time_us()
+#define ggml_perf_cycles()        ggml_cycles()
+#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
+#else
+#define ggml_perf_time_ms()       0
+#define ggml_perf_time_us()       0
+#define ggml_perf_cycles()        0
+#define ggml_perf_cycles_per_ms() 0
+#endif
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
+#else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+#endif
+
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+//
+// quantization
+//
+
+#define QK 32
+
+// method 5
+// blocks of QK elements
+// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    float   * restrict pd = (float *)   (y);
+    uint8_t * restrict pb = (uint8_t *) (pd + nb);
+
+    uint8_t pp[QK/2];
+
+#if __ARM_NEON
+#if QK == 32
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
+                MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        pd[i] = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
+            const int32x4_t   vi = vcvtq_s32_f32(vf);
+
+            pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+
+        memcpy(pb + i*16, pp, sizeof(pp));
+    }
+#else
+#error "not implemented for QK"
+#endif
+#elif defined(__wasm_simd128__)
+#if QK == 32
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = wasm_v128_load(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        pd[i] = d;
+
+        for (int l = 0; l < 8; l++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
+            const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+
+            pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+        }
+
+        memcpy(pb + i*16, pp, sizeof(pp));
+    }
+#else
+#error "not implemented for QK"
+#endif
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        pd[i] = d;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = x[i*QK + l + 0]*id;
+            const float v1 = x[i*QK + l + 1]*id;
+
+            const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
+            const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(pb + i*QK/2, pp, sizeof(pp));
+    }
+#endif
+}
+
+// method 4
+// blocks of QK elements
+// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    float   * restrict pm = (float *)   (y);
+    float   * restrict pd = (float *)   (pm + nb);
+    uint8_t * restrict pb = (uint8_t *) (pd + nb);
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        pm[i] = min;
+        pd[i] = d;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = (x[i*QK + l + 0] - min)*id;
+            const float v1 = (x[i*QK + l + 1] - min)*id;
+
+            const uint8_t vi0 = round(v0);
+            const uint8_t vi1 = round(v1);
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(pb + i*QK/2, pp, sizeof(pp));
+    }
+}
+
+// TODO: vectorize
+void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    const float   * restrict pd = (const float *)   (x);
+    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d = pd[i];
+
+        const uint8_t * restrict pp = pb + i*QK/2;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+}
+
+void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    const float   * restrict pm = (const float *)   (x);
+    const float   * restrict pd = (const float *)   (pm + nb);
+    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+
+    for (int i = 0; i < nb; i++) {
+        const float m = pm[i];
+        const float d = pd[i];
+
+        const uint8_t * restrict pp = pb + i*QK/2;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+}
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires to define the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#if defined(__ARM_FEATURE_QRDMX)
+    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#else
+    #define GGML_F32x4_REDUCE_ONE(x) \
+    (vgetq_lane_f32(x, 0) +          \
+     vgetq_lane_f32(x, 1) +          \
+     vgetq_lane_f32(x, 2) +          \
+     vgetq_lane_f32(x, 3))
+#endif
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+        x[2*i] = vaddq_f32(x[2*i], x[2*i+1]);  \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+        x[4*i] = vaddq_f32(x[4*i], x[4*i+2]);  \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+        x[8*i] = vaddq_f32(x[8*i], x[8*i+4]);  \
+    }                                          \
+    res = GGML_F32x4_REDUCE_ONE(x[0]);         \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD         vld1q_f16
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                             \
+    {                                                             \
+        for (int i = 0; i < GGML_F16_ARR/2; ++i) {                \
+            x[2*i] = vaddq_f16(x[2*i], x[2*i+1]);                 \
+        }                                                         \
+        for (int i = 0; i < GGML_F16_ARR/4; ++i) {                \
+            x[4*i] = vaddq_f16(x[4*i], x[4*i+2]);                 \
+        }                                                         \
+        for (int i = 0; i < GGML_F16_ARR/8; ++i) {                \
+            x[8*i] = vaddq_f16(x[8*i], x[8*i+4]);                 \
+        }                                                         \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
+        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+    }
+
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16(x))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+{                                                                 \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) {                    \
+        x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]);                 \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) {                    \
+        x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]);                 \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) {                    \
+        x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]);                 \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));                     \
+}
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
+
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              vector float
+#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+        x[2*i] = vec_add(x[2*i], x[2*i+1]);    \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+        x[4*i] = vec_add(x[4*i], x[4*i+2]);    \
+    }                                          \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+        x[8*i] = vec_add(x[8*i], x[8*i+4]);    \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+#define GGML_F16_STEP       GGML_F32_STEP
+#define GGML_F16_EPR        GGML_F32_EPR
+#define GGML_F16_VEC        GGML_F32x4
+#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p))
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) {     \
+        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    }                                              \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) {     \
+        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    }                                              \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) {     \
+        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+    return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
+
+    wasm_v128_store(tmp, x);
+
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                  \
+{                                                  \
+    for (int i = 0; i < GGML_F16_ARR/2; ++i) {     \
+        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    }                                              \
+    for (int i = 0; i < GGML_F16_ARR/4; ++i) {     \
+        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    }                                              \
+    for (int i = 0; i < GGML_F16_ARR/8; ++i) {     \
+        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F16_VEC                GGML_F16x4
+#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD    _mm_loadu_ps
+#define GGML_F32x4_STORE   _mm_storeu_ps
+#if defined(__FMA__)
+    // TODO: Does this work?
+    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD     _mm_add_ps
+#define GGML_F32x4_MUL     _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x)                                 \
+{                                                                 \
+    for (int i = 0; i < GGML_F32_ARR/2; ++i) {                    \
+        x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]);                    \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/4; ++i) {                    \
+        x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]);                    \
+    }                                                             \
+    for (int i = 0; i < GGML_F32_ARR/8; ++i) {                    \
+        x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]);                    \
+    }                                                             \
+    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
+    res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0));                     \
+}
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
+
+    _mm_storeu_ps(arr, y);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         _mm_add_ps
+#define GGML_F32Cx4_MUL         _mm_mul_ps
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+
+inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+    ggml_float sumf = 0.0;
+
+#ifdef GGML_SIMD
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#endif
+
+    *s = sumf;
+}
+
+inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
+
+inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = n / QK;
+
+    assert(n % QK == 0);
+    assert(nb % 2 == 0);
+
+    const float * restrict pd0 = (const float *) x;
+    const float * restrict pd1 = (const float *) y;
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#ifdef __ARM_NEON
+#if QK == 32
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = pd0[i + 0];
+        const float d1_0 = pd1[i + 0];
+        const float d0_1 = pd0[i + 1];
+        const float d1_1 = pd1[i + 1];
+
+        //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
+
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
+#else
+        sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
+        sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
+#endif
+    }
+
+    sumf = sum0 + sum1;
+#else
+#error "not implemented for QK"
+#endif
+#elif defined(__wasm_simd128__)
+#if QK == 32
+    // wasm simd
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = pd0[i + 0];
+        const float d0_1 = pd0[i + 1];
+        const float d1_0 = pd1[i + 0];
+        const float d1_1 = pd1[i + 1];
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const v128_t m4b = wasm_u8x16_splat(0xf);
+        const v128_t s8b = wasm_i8x16_splat(0x8);
+
+        const v128_t v0_0 = wasm_v128_load(p0);
+        const v128_t v0_1 = wasm_v128_load(p0 + 16);
+        const v128_t v1_0 = wasm_v128_load(p1);
+        const v128_t v1_1 = wasm_v128_load(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
+
+        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
+
+        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
+
+        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
+
+        // sub 8
+        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
+
+        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
+
+        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
+
+        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
+        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
+
+        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
+        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
+
+        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
+        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
+
+        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
+        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
+
+        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
+        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
+
+        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
+        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
+
+        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
+        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
+
+        sum0 += d0_0*d1_0*(
+                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
+                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
+                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
+                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
+        sum1 += d0_1*d1_1*(
+                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
+                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
+                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
+                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
+    }
+
+    sumf = sum0 + sum1;
+#else
+#error "not implemented for QK"
+#endif
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = pd0[i];
+        const float d1 = pd1[i];
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
+inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = n / QK;
+
+    const float * restrict pm0 = (const float *) x;
+    const float * restrict pm1 = (const float *) y;
+
+    const float * restrict pd0 = (const float *) (pm0 + nb);
+    const float * restrict pd1 = (const float *) (pm1 + nb);
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#if 1
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float m0 = pm0[i];
+        const float m1 = pm1[i];
+
+        const float d0 = pd0[i];
+        const float d1 = pd1[i];
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*(v1 & 0xf) + m1;
+            const float f3 = d1*(v1 >> 4)  + m1;
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
+inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
+    assert(n % QK == 0);
+
+    const int nb = n / QK;
+
+    const float   * restrict pd = (const float *)   (x);
+    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+
+#if __ARM_NEON
+#if QK == 32
+    for (int i = 0; i < nb; ++i) {
+        const float d0 = pd[i]*v;
+
+        const uint8_t * restrict pp = pb + i*16;
+
+        const uint8x8_t m4b = vdup_n_u8(0xf);
+        const int8x8_t  s8b = vdup_n_s8(0x8);
+
+        const float32x4_t vd = vdupq_n_f32(d0);
+
+        for (int j = 0; j < 2; j++) {
+            const uint8x8_t vx = vld1_u8(pp + j*8);
+
+            const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b));
+            const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4));
+
+            // sub 8
+            const int8x8_t vxls = vsub_s8(vxl, s8b);
+            const int8x8_t vxhs = vsub_s8(vxh, s8b);
+
+            //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0];
+            //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1];
+            const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
+            const int8x8_t vxht = vzip2_s8(vxls, vxhs);
+
+            const int8x16_t vxq = vcombine_s8(vxlt, vxht);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
+            const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
+            const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
+            const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
+            const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
+
+            const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
+            const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
+            const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
+            const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
+
+            const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
+            const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
+            const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
+            const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
+
+            vst1q_f32(y + i*32 + j*16 + 0,  vr0);
+            vst1q_f32(y + i*32 + j*16 + 4,  vr1);
+            vst1q_f32(y + i*32 + j*16 + 8,  vr2);
+            vst1q_f32(y + i*32 + j*16 + 12, vr3);
+        }
+    }
+#endif
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d = pd[i];
+
+        const uint8_t * restrict pp = pb + i*QK/2;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK + l + 0] += v0*v;
+            y[i*QK + l + 1] += v1*v;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isinf(y[i*QK + l + 0]));
+            assert(!isinf(y[i*QK + l + 1]));
+        }
+    }
+#endif
+}
+
+inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
+    assert(n % QK == 0);
+
+    const int nb = n / QK;
+
+    const float   * restrict pm = (const float *)   (x);
+    const float   * restrict pd = (const float *)   (pm + nb);
+    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+
+    for (int i = 0; i < nb; i++) {
+        const float m = pm[i];
+        const float d = pd[i];
+
+        const uint8_t * restrict pp = pb + i*QK/2;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const uint8_t vi0 = vi & 0xf;
+            const uint8_t vi1 = vi >> 4;
+
+            const float v0 = d*vi0 + m;
+            const float v1 = d*vi1 + m;
+
+            y[i*QK + l + 0] += v0*v;
+            y[i*QK + l + 1] += v1*v;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isinf(y[i*QK + l + 0]));
+            assert(!isinf(y[i*QK + l + 1]));
+            //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
+        }
+    }
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+
+static const ggml_float GELU_COEF_A    = 0.044715;
+static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+
+inline static float ggml_gelu_f32(float x) {
+    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_gelu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+#endif
+
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0 + exp(-x));
+}
+
+inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_silu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_SILU_FP16
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+#endif
+
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += x[i];
+    }
+    *s = sum;
+#else
+    vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float max = -INFINITY;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    *s = max;
+#else
+    vDSP_maxv(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+//
+// data types
+//
+
+static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
+    QK,
+    QK,
+    1,
+    1,
+    1,
+    1,
+    1,
+};
+
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+
+static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+    sizeof(float  )   + QK/2,
+    sizeof(float  )*2 + QK/2,
+    sizeof(int8_t ),
+    sizeof(int16_t),
+    sizeof(int32_t),
+    sizeof(ggml_fp16_t),
+    sizeof(float  ),
+};
+
+// don't forget to update the array above when adding new types
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+
+static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+    "NONE",
+
+    "DUP",
+    "ADD",
+    "SUB",
+    "MUL",
+    "DIV",
+    "SQR",
+    "SQRT",
+    "SUM",
+    "MEAN",
+    "REPEAT",
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "RELU",
+    "GELU",
+    "SILU",
+    "NORM",
+
+    "MUL_MAT",
+
+    "SCALE",
+    "CPY",
+    "RESHAPE",
+    "VIEW",
+    "PERMUTE",
+    "TRANSPOSE",
+    "GET_ROWS",
+    "DIAG_MASK_INF",
+    "SOFT_MAX",
+    "ROPE",
+    "CONV_1D_1S",
+    "CONV_1D_2S",
+
+    "FLASH_ATTN",
+    "FLASH_FF",
+};
+
+static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+
+static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+    "none",
+
+    "x",
+    "x+y",
+    "x-y",
+    "x*y",
+    "x/y",
+    "x^2",
+    "√x",
+    "Σx",
+    "Σx/n",
+    "repeat(x)",
+    "abs(x)",
+    "sgn(x)",
+    "-x",
+    "step(x)",
+    "relu(x)",
+    "gelu(x)",
+    "silu(x)",
+    "norm(x)",
+
+    "X*Y",
+
+    "x*v",
+    "x-\\>y",
+    "reshape(x)",
+    "view(x)",
+    "permute(x)",
+    "transpose(x)",
+    "get_rows(x)",
+    "diag_mask_inf(x)",
+    "soft_max(x)",
+    "rope(x)",
+    "conv_1d_1s(x)",
+    "conv_1d_2s(x)",
+
+    "flash_attn(x)",
+    "flash_ff(x)",
+};
+
+static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+
+//
+// ggml object
+//
+
+struct ggml_object {
+    size_t offs;
+    size_t size;
+
+    struct ggml_object * next;
+
+    char padding[8];
+};
+
+static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+//
+// ggml context
+//
+
+struct ggml_context {
+    size_t mem_size;
+    void * mem_buffer;
+    bool   mem_buffer_owned;
+
+    int n_objects;
+
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
+
+    struct ggml_scratch scratch;
+    struct ggml_scratch scratch_save;
+};
+
+struct ggml_context_container {
+    bool used;
+
+    struct ggml_context context;
+};
+
+//
+// compute types
+//
+
+enum ggml_task_type {
+    GGML_TASK_INIT = 0,
+    GGML_TASK_COMPUTE,
+    GGML_TASK_FINALIZE,
+};
+
+struct ggml_compute_params {
+    enum ggml_task_type type;
+
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+    struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+};
+
+// global state
+static struct ggml_state g_state;
+static atomic_int g_state_barrier = 0;
+
+// barrier via spin lock
+inline static void ggml_critical_section_start(void) {
+    int processing = atomic_fetch_add(&g_state_barrier, 1);
+
+    while (processing > 0) {
+        // wait for other threads to finish
+        atomic_fetch_sub(&g_state_barrier, 1);
+        sched_yield(); // TODO: reconsider this
+        processing = atomic_fetch_add(&g_state_barrier, 1);
+    }
+}
+
+// TODO: make this somehow automatically executed
+//       some sort of "sentry" mechanism
+inline static void ggml_critical_section_end(void) {
+    atomic_fetch_sub(&g_state_barrier, 1);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+    GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
+            obj->offs, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+    while (obj != NULL) {
+        ggml_print_object(obj);
+        obj = obj->next;
+    }
+
+    GGML_PRINT("%s: --- end ---\n", __func__);
+}
+
+int ggml_nelements(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+int ggml_nrows(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+}
+
+int ggml_blck_size(enum ggml_type type) {
+    return GGML_BLCK_SIZE[type];
+}
+
+size_t ggml_type_size(enum ggml_type type) {
+    return GGML_TYPE_SIZE[type];
+}
+
+float ggml_type_sizef(enum ggml_type type) {
+    return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
+}
+
+size_t ggml_element_size(const struct ggml_tensor * tensor) {
+    return GGML_TYPE_SIZE[tensor->type];
+}
+
+static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+static inline bool ggml_is_vector(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[0]  == t1->ne[0])  &&
+        (t0->ne[2]  == t1->ne[2])  &&
+        (t0->ne[3]  == t1->ne[3]);
+}
+
+static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[0] == t1->ne[0] ) &&
+        (t0->ne[1] == t1->ne[1] ) &&
+        (t0->ne[2] == t1->ne[2] ) &&
+        (t0->ne[3] == t1->ne[3] );
+}
+
+// check if t1 can be represented as a repeatition of t0
+static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t1->ne[0]%t0->ne[0] == 0) &&
+        (t1->ne[1]%t0->ne[1] == 0) &&
+        (t1->ne[2]%t0->ne[2] == 0) &&
+        (t1->ne[3]%t0->ne[3] == 0);
+}
+
+static inline int ggml_up32(int n) {
+    return (n + 31) & ~31;
+}
+
+static inline int ggml_up64(int n) {
+    return (n + 63) & ~63;
+}
+
+static inline int ggml_up(int n, int m) {
+    // assert m is a power of 2
+    GGML_ASSERT((m & (m - 1)) == 0);
+    return (n + m - 1) & ~(m - 1);
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define ggml_assert_aligned(ptr) \
+    assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+    // make this function thread safe
+    ggml_critical_section_start();
+
+    static bool is_first_call = true;
+
+    if (is_first_call) {
+        // initialize GELU, SILU and EXP F32 tables
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+            ggml_fp16_t ii;
+            for (int i = 0; i < (1 << 16); ++i) {
+                uint16_t ui = i;
+                memcpy(&ii, &ui, sizeof(ii));
+                const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
+                table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        }
+
+        // initialize g_state
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+            g_state = (struct ggml_state) {
+                /*.contexts =*/ { { 0 } },
+            };
+
+            for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
+                g_state.contexts[i].used = false;
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        }
+
+        is_first_call = false;
+    }
+
+    // find non-used context in g_state
+    struct ggml_context * ctx = NULL;
+
+    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+        if (!g_state.contexts[i].used) {
+            g_state.contexts[i].used = true;
+            ctx = &g_state.contexts[i].context;
+
+            GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
+            break;
+        }
+    }
+
+    if (ctx == NULL) {
+        GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
+
+        ggml_critical_section_end();
+
+        return NULL;
+    }
+
+    *ctx = (struct ggml_context) {
+        /*.mem_size         =*/ params.mem_size,
+        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
+        /*.n_objects        =*/ 0,
+        /*.objects_begin    =*/ NULL,
+        /*.objects_end      =*/ NULL,
+        /*.scratch          =*/ { 0, 0, NULL, },
+        /*.scratch_save     =*/ { 0, 0, NULL, },
+    };
+
+    ggml_assert_aligned(ctx->mem_buffer);
+
+    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+    ggml_critical_section_end();
+
+    return ctx;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+    // make this function thread safe
+    ggml_critical_section_start();
+
+    bool found = false;
+
+    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+        if (&g_state.contexts[i].context == ctx) {
+            g_state.contexts[i].used = false;
+
+            GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
+                    __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
+
+            if (ctx->mem_buffer_owned) {
+                free(ctx->mem_buffer);
+            }
+
+            found = true;
+            break;
+        }
+    }
+
+    if (!found) {
+        GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+    }
+
+    ggml_critical_section_end();
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+    return ctx->objects_end->offs + ctx->objects_end->size;
+}
+
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
+    const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
+
+    ctx->scratch = scratch;
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_tensor * ggml_new_tensor_impl(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int* ne,
+        void*  data) {
+    // always insert objects at the end of the context's memory pool
+    struct ggml_object * obj_cur = ctx->objects_end;
+
+    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
+    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+    const size_t cur_end  = cur_offs + cur_size;
+
+    size_t size_needed = 0;
+
+    if (data == NULL) {
+        size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        for (int i = 1; i < n_dims; i++) {
+            size_needed *= ne[i];
+        }
+        // align to GGML_MEM_ALIGN
+        size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
+    }
+
+    char * const mem_buffer = ctx->mem_buffer;
+    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+    if (ctx->scratch.data == NULL || data != NULL) {
+        size_needed += sizeof(struct ggml_tensor);
+
+        if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                    __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
+            assert(false);
+            return NULL;
+        }
+
+        *obj_new = (struct ggml_object) {
+            .offs = cur_end + GGML_OBJECT_SIZE,
+            .size = size_needed,
+            .next = NULL,
+        };
+    } else {
+        if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
+            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            assert(false);
+            return NULL;
+        }
+
+        if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
+            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                    __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
+            assert(false);
+            return NULL;
+        }
+
+        data = (char * const) ctx->scratch.data + ctx->scratch.offs;
+
+        *obj_new = (struct ggml_object) {
+            .offs = cur_end + GGML_OBJECT_SIZE,
+            .size = sizeof(struct ggml_tensor),
+            .next = NULL,
+        };
+
+        //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
+
+        ctx->scratch.offs += size_needed;
+    }
+
+    if (obj_cur != NULL) {
+        obj_cur->next = obj_new;
+    } else {
+        // this is the first object in this context
+        ctx->objects_begin = obj_new;
+    }
+
+    ctx->objects_end = obj_new;
+
+    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+
+    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
+
+    ggml_assert_aligned(result);
+
+    *result = (struct ggml_tensor) {
+        /*.type         =*/ type,
+        /*.n_dims       =*/ n_dims,
+        /*.ne           =*/ { 1, 1, 1, 1 },
+        /*.nb           =*/ { 0, 0, 0, 0 },
+        /*.op           =*/ GGML_OP_NONE,
+        /*.is_param     =*/ false,
+        /*.grad         =*/ NULL,
+        /*.src0         =*/ NULL,
+        /*.src1         =*/ NULL,
+        /*.opt          =*/ { NULL },
+        /*.n_tasks      =*/ 0,
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+        /*.data         =*/ data == NULL ? (void *)(result + 1) : data,
+        /*.pad          =*/ { 0 },
+    };
+
+    ggml_assert_aligned(result->data);
+
+    for (int i = 0; i < n_dims; i++) {
+        result->ne[i] = ne[i];
+    }
+
+    result->nb[0] = GGML_TYPE_SIZE[type];
+    result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+    }
+
+    ctx->n_objects++;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int * ne) {
+    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0) {
+    return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1) {
+    const int ne[2] = { ne0, ne1 };
+    return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2) {
+    const int ne[3] = { ne0, ne1, ne2 };
+    return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2,
+        int    ne3) {
+    const int ne[4] = { ne0, ne1, ne2, ne3 };
+    return ggml_new_tensor(ctx, type, 4, ne);
+}
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ctx->scratch = ctx->scratch_save;
+
+    ggml_set_i32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+    ctx->scratch = ctx->scratch_save;
+
+    ggml_set_f32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL);
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    memset(tensor->data, 0, ggml_nbytes(tensor));
+    return tensor;
+}
+
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    return tensor;
+}
+
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+    return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+    assert(tensor->type == GGML_TYPE_F32);
+    return (float *)(tensor->data);
+}
+
+struct ggml_tensor * ggml_view_tensor(
+        struct ggml_context * ctx,
+        const struct ggml_tensor * src) {
+    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_dup
+
+struct ggml_tensor * ggml_dup_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_DUP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_dup(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+struct ggml_tensor * ggml_add_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ADD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_sub
+
+struct ggml_tensor * ggml_sub_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SUB;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sub(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+struct ggml_tensor * ggml_mul_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    if (inplace) {
+        GGML_ASSERT(is_node == false);
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MUL;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_mul(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+struct ggml_tensor * ggml_div_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    if (inplace) {
+        GGML_ASSERT(is_node == false);
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_DIV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_div(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+struct ggml_tensor * ggml_sqr_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SQR;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+struct ggml_tensor * ggml_sqrt_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SQRT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_SUM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement
+        is_node = true;
+    }
+
+    int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
+
+    result->op   = GGML_OP_MEAN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ABS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_abs(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_abs_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_abs_impl(ctx, a, true);
+}
+
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SGN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sgn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sgn_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sgn_impl(ctx, a, true);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_NEG;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_neg(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_neg_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_neg_impl(ctx, a, true);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_STEP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_step(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_step_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_step_impl(ctx, a, true);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_relu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_relu_impl(ctx, a, true);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_impl(ctx, a, true);
+}
+
+// ggml_silu
+
+struct ggml_tensor * ggml_silu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SILU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_silu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_silu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_silu_impl(ctx, a, true);
+}
+
+// ggml_norm
+
+struct ggml_tensor * ggml_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_mul_mat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_MUL_MAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_scale
+
+struct ggml_tensor * ggml_scale_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_is_scalar(b));
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_SCALE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_scale(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_scale_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_scale_impl(ctx, a, b, true);
+}
+
+// ggml_cpy
+
+struct ggml_tensor * ggml_cpy_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // make a view of the destination
+    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+
+    result->op   = GGML_OP_CPY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_cpy_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, true);
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[2] = { ne0, ne1 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[3] = { ne0, ne1, ne2 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        size_t                offset) {
+    if (a->grad) {
+        GGML_ASSERT(false); // gradient propagation is not supported
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        size_t                nb1,
+        size_t                offset) {
+    if (a->grad) {
+        GGML_ASSERT(false); // gradient propagation is not supported
+    }
+
+    const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = result->nb[1]*ne1;
+    result->nb[3] = result->nb[2];
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   axis0,
+        int                   axis1,
+        int                   axis2,
+        int                   axis3) {
+    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+    GGML_ASSERT(axis0 != axis1);
+    GGML_ASSERT(axis0 != axis2);
+    GGML_ASSERT(axis0 != axis3);
+    GGML_ASSERT(axis1 != axis2);
+    GGML_ASSERT(axis1 != axis3);
+    GGML_ASSERT(axis2 != axis3);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    int ne[GGML_MAX_DIMS];
+    int nb[GGML_MAX_DIMS];
+
+    ne[axis0] = a->ne[0];
+    ne[axis1] = a->ne[1];
+    ne[axis2] = a->ne[2];
+    ne[axis3] = a->ne[3];
+
+    nb[axis0] = a->nb[0];
+    nb[axis1] = a->nb[1];
+    nb[axis2] = a->nb[2];
+    nb[axis3] = a->nb[3];
+
+    result->ne[0] = ne[0];
+    result->ne[1] = ne[1];
+    result->ne[2] = ne[2];
+    result->ne[3] = ne[3];
+
+    result->nb[0] = nb[0];
+    result->nb[1] = nb[1];
+    result->nb[2] = nb[2];
+    result->nb[3] = nb[3];
+
+    result->op   = GGML_OP_PERMUTE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the permutation here?
+
+    return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->ne[0] = a->ne[1];
+    result->ne[1] = a->ne[0];
+
+    result->nb[0] = a->nb[1];
+    result->nb[1] = a->nb[0];
+
+    result->op   = GGML_OP_TRANSPOSE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: implement non F32 return
+    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+
+    result->op   = GGML_OP_GET_ROWS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_diag_mask_inf
+
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+
+    result->op   = GGML_OP_DIAG_MASK_INF;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_soft_max
+
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_rope
+
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_dims;
+    ((int32_t *) b->data)[2] = mode;
+
+    result->op   = GGML_OP_ROPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_conv_1d_1s
+
+struct ggml_tensor * ggml_conv_1d_1s(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+    result->op   = GGML_OP_CONV_1D_1S;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_conv_1d_2s
+
+struct ggml_tensor * ggml_conv_1d_2s(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+    result->op   = GGML_OP_CONV_1D_2S;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_flash_attn
+
+struct ggml_tensor * ggml_flash_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        bool                  masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+
+    result->op   = GGML_OP_FLASH_ATTN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+// ggml_flash_ff
+
+struct ggml_tensor * ggml_flash_ff(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b0,
+        struct ggml_tensor  * b1,
+        struct ggml_tensor  * c0,
+        struct ggml_tensor  * c1) {
+    GGML_ASSERT(ggml_can_mul_mat(b0, a));
+    // TODO: more checks
+
+    bool is_node = false;
+
+    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+
+    result->op   = GGML_OP_FLASH_FF;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b0;
+    result->opt[0] = b1;
+    result->opt[1] = c0;
+    result->opt[2] = c1;
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+        struct ggml_context * ctx,
+        struct ggml_tensor * tensor) {
+    tensor->is_param = true;
+
+    GGML_ASSERT(tensor->grad == NULL);
+    tensor->grad = ggml_dup_tensor(ctx, tensor);
+}
+
+// ggml_compute_forward_dup
+
+static void ggml_compute_forward_dup_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        return;
+    }
+
+    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            const size_t rs = ne00*nb00;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        char * dst_ptr = (char *) dst->data + id*rs;
+
+                        memcpy(dst_ptr, src0_ptr, rs);
+
+                        id++;
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    } else {
+        //printf("%s: this is not optimal - fix me\n", __func__);
+
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = *src0_ptr;
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    }
+}
+
+static void ggml_compute_forward_dup_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+        memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+        return;
+    }
+
+    if (src0->nb[0] == sizeof(float)) {
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            const size_t rs = ne00*nb00;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        char * dst_ptr = (char *) dst->data + id*rs;
+
+                        memcpy(dst_ptr, src0_ptr, rs);
+
+                        id++;
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    } else {
+        //printf("%s: this is not optimal - fix me\n", __func__);
+
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = *src0_ptr;
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    }
+}
+
+static void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_dup_f16(params, src0, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_dup_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        const int j0 = (n/nth)*ith;
+        const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
+
+        for (int j = j0; j < j1; j++) {
+            ggml_vec_add_f32(nc,
+                    (float *) ((char *) dst->data  + j*nb1),
+                    (float *) ((char *) src0->data + j*nb01),
+                    (float *) ((char *) src1->data + j*nb11));
+        }
+    } else {
+        // src1 is not contiguous
+        for (int j = ith; j < n; j += nth) {
+            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
+            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sub
+
+static void ggml_compute_forward_sub_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sub_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sub(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sub_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mul
+
+static void ggml_compute_forward_mul_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_mul_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_mul(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_div
+
+static void ggml_compute_forward_div_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_div_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_div(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_div_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sqr
+
+static void ggml_compute_forward_sqr_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n     = ggml_nrows(src0);
+    const int nc    = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqr_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqr(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqr_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sqrt
+
+static void ggml_compute_forward_sqrt_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqrt_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqrt(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqrt_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_scalar(dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+    assert(src0->nb[0] == sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) (dst->data),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sum(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    assert(ne0 == 1);
+    assert(ne1 == ne01);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    UNUSED(ne0);
+    UNUSED(ne1);
+    UNUSED(ne2);
+    UNUSED(ne3);
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mean(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mean_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: implement support for rank > 2 tensors
+    assert(src0->ne[2] == 1);
+    assert(src0->ne[3] == 1);
+    assert( dst->ne[2] == 1);
+    assert( dst->ne[3] == 1);
+
+    const int nc  = dst->ne[0];
+    const int nr  = dst->ne[1];
+    const int nc0 = src0->ne[0];
+    const int nr0 = src0->ne[1];
+    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+
+    // TODO: support for transposed / permuted tensors
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    // TODO: maybe this is not optimal?
+    for (int i = 0; i < nrr; i++) {
+        for (int j = 0; j < ncr; j++) {
+            for (int k = 0; k < nr0; k++) {
+                ggml_vec_cpy_f32(nc0,
+                        (float *) ((char *)  dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
+                        (float *) ((char *) src0->data + (        k)*(src0->nb[1])));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_abs(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_abs_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sgn
+
+static void ggml_compute_forward_sgn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sgn_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sgn(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sgn_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_neg
+
+static void ggml_compute_forward_neg_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_neg_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_neg(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_neg_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_step
+
+static void ggml_compute_forward_step_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_step_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_step(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_step_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_relu
+
+static void ggml_compute_forward_relu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_relu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_relu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_gelu
+
+static void ggml_compute_forward_gelu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX gelu\n");
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    ggml_float v = x[i00] - mean;
+                    y[i00] = v;
+                    sum2 += v*v;
+                }
+
+                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mul_mat
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_compute_forward_mul_mat_use_blas(
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    UNUSED(src0);
+
+    const int ne10 = src1->ne[0];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+        //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
+        return true;
+    }
+
+    return false;
+}
+#endif
+
+static void ggml_compute_forward_mul_mat_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    assert(ne02 == ne12);
+    assert(ne03 == ne13);
+    assert(ne2  == ne12);
+    assert(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    assert(nb0 == sizeof(float));
+    assert(nb0 <= nb1);
+    assert(nb1 <= nb2);
+    assert(nb2 <= nb3);
+
+    assert(ne0 == ne01);
+    assert(ne1 == ne11);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                const float * x = (float *) (src0->data);
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                {
+                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                            ne11, ne01, ne10,
+                            1.0f,    y, ne10,
+                                     x, ne10,
+                            0.0f,    d, ne01);
+                }
+            }
+        }
+
+        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+        float * const wdata = params->wdata;
+
+        // cols per thread
+        const int dc = (ne + nth - 1)/nth;
+
+        // col range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, ne);
+
+        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+
+        for (int k = 1; k < nth; k++) {
+            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // TODO: do not support transposed src1
+        assert(nb10 == sizeof(float));
+
+        // parallelize by src0 rows using ggml_vec_dot_f32
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                // src1 indices
+                const int i13 = i03;
+                const int i12 = i02;
+                const int i11 = ic;
+
+                // dst indices
+                const int i0 = i01;
+                const int i1 = i11;
+                const int i2 = i02;
+                const int i3 = i03;
+
+                ggml_vec_dot_f32(ne00,
+                        (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                        (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                        (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+            }
+        }
+    } else {
+        // parallelize by src1 columns using ggml_vec_mad_f32
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        float * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        // dst indices
+                        const int i1 = i11;
+                        const int i2 = i12;
+                        const int i3 = i13;
+
+                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        ggml_vec_mad_f32(ne01,
+                                (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
+                                (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
+                               *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_mul_mat_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    int id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        for (int i00 = 0; i00 < ne00; ++i00) {
+                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        }
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                //      float * z =                          wdata + ne00*ne01;
+
+                // z = x * yT
+                //{
+                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                //            ne01, ne11, ne00,
+                //            1.0f, x, ne00,
+                //                  y, ne00,
+                //            0.0f, z, ne11);
+                //}
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // transpose z
+                //for (int j = 0; j < ne11; ++j) {
+                //    for (int i = 0; i < ne01; ++i) {
+                //        d[j*ne01 + i] = z[i*ne11 + j];
+                //    }
+                //}
+
+                {
+#if 1
+                    // zT = y * xT
+                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                            ne11, ne01, ne10,
+                            1.0f,    y, ne00,
+                                     x, ne00,
+                            0.0f,    d, ne01);
+#else
+                    // zT = (xT * y)T
+                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+                            ne01, ne11, ne10,
+                            1.0f,    x, ne00,
+                                     y, ne00,
+                            0.0f,    d, ne01);
+#endif
+                }
+            }
+        }
+
+        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        if (nb01 >= nb00) {
+            ggml_fp16_t * const wdata = params->wdata;
+
+            int id = 0;
+            for (int i13 = 0; i13 < ne13; ++i13) {
+                for (int i12 = 0; i12 < ne12; ++i12) {
+                    for (int i11 = 0; i11 < ne11; ++i11) {
+                        for (int i10 = 0; i10 < ne10; ++i10) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                        }
+                    }
+                }
+            }
+
+            GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+        ggml_fp16_t * const wdata = params->wdata;
+
+        // cols per thread
+        const int dc = (ne + nth - 1)/nth;
+
+        // col range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, ne);
+
+        for (int i = ic0; i < ic1; ++i) {
+            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
+        }
+
+        for (int k = 1; k < nth; k++) {
+            for (int i = ic0; i < ic1; ++i) {
+                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+            }
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // fp16 -> half the size, so divide by 2
+        // TODO: do not support transposed src1
+        assert(nb10/2 == sizeof(ggml_fp16_t));
+
+        // parallelize by src0 rows using ggml_vec_dot_f16
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        ggml_fp16_t * wdata = params->wdata;
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int i13 = i03;
+            const int i12 = i02;
+
+            const int i0 = i01;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
+
+            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+            assert(ne00 % 32 == 0);
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+            }
+        }
+    } else {
+        // parallelize by src1 columns using ggml_vec_mad_f16
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        ggml_fp16_t * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    // dst indices
+                    const int i1 = i11;
+                    const int i2 = i12;
+                    const int i3 = i13;
+
+                    ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        ggml_fp16_t * src0_col =  (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+                        float         src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+                        ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_mul_mat_q4_0_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    int id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        //for (int i00 = 0; i00 < ne00; ++i00) {
+                        //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        //}
+                        dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                //      float * z =                          wdata + ne00*ne01;
+
+                // z = x * yT
+                //{
+                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                //            ne01, ne11, ne00,
+                //            1.0f, x, ne00,
+                //                  y, ne00,
+                //            0.0f, z, ne11);
+                //}
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // transpose z
+                //for (int j = 0; j < ne11; ++j) {
+                //    for (int i = 0; i < ne01; ++i) {
+                //        d[j*ne01 + i] = z[i*ne11 + j];
+                //    }
+                //}
+
+                {
+#if 1
+                    // zT = y * xT
+                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                            ne11, ne01, ne10,
+                            1.0f,    y, ne00,
+                                     x, ne00,
+                            0.0f,    d, ne01);
+#else
+                    // zT = (xT * y)T
+                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+                            ne01, ne11, ne10,
+                            1.0f,    x, ne00,
+                                     y, ne00,
+                            0.0f,    d, ne01);
+#endif
+                }
+            }
+        }
+
+        /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
+        if (nb01 >= nb00) {
+            char * wdata = params->wdata;
+
+            for (int i13 = 0; i13 < ne13; ++i13) {
+                for (int i12 = 0; i12 < ne12; ++i12) {
+                    for (int i11 = 0; i11 < ne11; ++i11) {
+                        //for (int i10 = 0; i10 < ne10; ++i10) {
+                        //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                        //}
+                        quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+                    }
+                }
+            }
+
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        // cols per thread
+        const int dc = (ne + nth - 1)/nth;
+
+        // col range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, ne);
+
+        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+
+        for (int k = 1; k < nth; k++) {
+            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // TODO: do not support transposed src1
+
+        // parallelize by src0 rows using ggml_vec_dot_q4_0
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        void * wdata = params->wdata;
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int i13 = i03;
+            const int i12 = i02;
+
+            const int i0 = i01;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+            char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
+
+            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+            assert(ne00 % 32 == 0);
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
+            }
+        }
+    } else {
+        //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
+        // parallelize by src1 columns using ggml_vec_mad_q4_0
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        float * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    // dst indices
+                    const int i1 = i11;
+                    const int i2 = i12;
+                    const int i3 = i13;
+
+                    float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        void * src0_col =   (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+                        float  src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+                        ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val);
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_mul_mat_q4_1_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    int id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        //for (int i00 = 0; i00 < ne00; ++i00) {
+                        //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        //}
+                        dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                //      float * z =                          wdata + ne00*ne01;
+
+                // z = x * yT
+                //{
+                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                //            ne01, ne11, ne00,
+                //            1.0f, x, ne00,
+                //                  y, ne00,
+                //            0.0f, z, ne11);
+                //}
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // transpose z
+                //for (int j = 0; j < ne11; ++j) {
+                //    for (int i = 0; i < ne01; ++i) {
+                //        d[j*ne01 + i] = z[i*ne11 + j];
+                //    }
+                //}
+
+                {
+#if 1
+                    // zT = y * xT
+                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                            ne11, ne01, ne10,
+                            1.0f,    y, ne00,
+                                     x, ne00,
+                            0.0f,    d, ne01);
+#else
+                    // zT = (xT * y)T
+                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+                            ne01, ne11, ne10,
+                            1.0f,    x, ne00,
+                                     y, ne00,
+                            0.0f,    d, ne01);
+#endif
+                }
+            }
+        }
+
+        //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
+        if (nb01 >= nb00) {
+            char * wdata = params->wdata;
+
+            for (int i13 = 0; i13 < ne13; ++i13) {
+                for (int i12 = 0; i12 < ne12; ++i12) {
+                    for (int i11 = 0; i11 < ne11; ++i11) {
+                        //for (int i10 = 0; i10 < ne10; ++i10) {
+                        //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                        //}
+                        quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+                    }
+                }
+            }
+
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        // cols per thread
+        const int dc = (ne + nth - 1)/nth;
+
+        // col range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, ne);
+
+        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+
+        for (int k = 1; k < nth; k++) {
+            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // TODO: do not support transposed src1
+
+        // parallelize by src0 rows using ggml_vec_dot_q4_1
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        void * wdata = params->wdata;
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int i13 = i03;
+            const int i12 = i02;
+
+            const int i0 = i01;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+            char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
+
+            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+            assert(ne00 % 32 == 0);
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
+            }
+        }
+    } else {
+        //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
+        // parallelize by src1 columns using ggml_vec_mad_q4_1
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        float * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    // dst indices
+                    const int i1 = i11;
+                    const int i2 = i12;
+                    const int i3 = i13;
+
+                    float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        void * src0_col =   (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+                        float  src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+                        ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val);
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_mul_mat(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+#if 0
+    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
+        static int first = 8;
+        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+        if (first) {
+            --first;
+        } else {
+            for (int k = 0; k < dst->ne[1]; ++k) {
+                for (int j = 0; j < dst->ne[0]/16; ++j) {
+                    for (int i = 0; i < 16; ++i) {
+                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+                    }
+                    printf("\n");
+                }
+                printf("\n");
+            }
+            printf("\n");
+            exit(0);
+        }
+    } else {
+        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    }
+#endif
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // scale factor
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v);
+    }
+}
+
+static void ggml_compute_forward_scale(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_scale_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cpy
+
+static void ggml_compute_forward_cpy(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
+// ggml_compute_forward_reshape
+
+static void ggml_compute_forward_reshape(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+static void ggml_compute_forward_view(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_permute
+
+static void ggml_compute_forward_permute(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_transpose
+
+static void ggml_compute_forward_transpose(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_q4_0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        dequantize_row_q4_0(
+                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_q4_1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        dequantize_row_q4_1(
+                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
+            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i*dst->nb[1]),
+                (float *) ((char *) src0->data + r*src0->nb[1]));
+    }
+}
+
+static void ggml_compute_forward_get_rows(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_get_rows_q4_1(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_inf_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 1);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int k = 0; k < nz; k++) {
+        for (int j = 0; j < nr; j++) {
+            for (int i = n_past; i < nc; i++) {
+                if (i > n_past + j) {
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag_mask_inf(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(p[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, p);
+
+        ggml_float sum = 0.0;
+
+        uint16_t scvt;
+        for (int i = 0; i < nc; i++) {
+            if (p[i] == -INFINITY) {
+                p[i] = 0.0f;
+            } else {
+                //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
+                memcpy(&scvt, &s, sizeof(scvt));
+                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                sum += val;
+                p[i] = val;
+            }
+        }
+
+        assert(sum > 0.0f);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, p, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(p[i]));
+            assert(!isinf(p[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_rope
+
+static void ggml_compute_forward_rope_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    //const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    const int nb3 = src0->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(float));
+
+    // TODO: optimize
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = (mode == 0 ? n_past + i2 : i2);
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < n_dims; i0 += 2) {
+                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+                    const double cos_theta = cos(p*theta);
+                    const double sin_theta = sin(p*theta);
+
+                    const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                          float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    double x0 = src[0];
+                    double x1 = src[1];
+
+                    dst_data[0] = x0*cos_theta - x1*sin_theta;
+                    dst_data[1] = x0*sin_theta + x1*cos_theta;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    //const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    const int nb3 = src0->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = (mode == 0 ? n_past + i2 : i2);
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < n_dims; i0 += 2) {
+                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+                    const double cos_theta = cos(p*theta);
+                    const double sin_theta = sin(p*theta);
+
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                          ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    double x0 = ggml_fp16_to_fp32(src[0]);
+                    double x1 = ggml_fp16_to_fp32(src[1]);
+
+                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_conv_1d_1s
+
+static void ggml_compute_forward_conv_1d_1s_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    //const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    //const int ne0  = dst->ne[0];
+    //const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    //const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    //const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    //const int nb2  = dst->nb[2];
+    //const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00;
+    const int nh = nk/2;
+
+    const int ew0 = ggml_up32(ne01);
+
+    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ew0 + i01] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+            for (int i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                ggml_fp16_t * dst_data = wdata;
+                for (int i10 = 0; i10 < ne10; i10++) {
+                    dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total rows in dst
+    const int nr = ne02;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        for (int i0 = 0; i0 < ne10; ++i0) {
+            dst_data[i0] = 0;
+            for (int k = -nh; k <= nh; k++) {
+                float v = 0.0f;
+                ggml_vec_dot_f16(ew0, &v,
+                        (ggml_fp16_t *) params->wdata +   i1*ew0*ne00 +      (nh + k)*ew0,
+                        (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+                dst_data[i0] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_1d_1s_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    //const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    //const int ne0  = dst->ne[0];
+    //const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    //const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    //const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    //const int nb2  = dst->nb[2];
+    //const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00;
+    const int nh = nk/2;
+
+    const int ew0 = ggml_up32(ne01);
+
+    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0)
+        {
+            float * const wdata = (float *) params->wdata + 0;
+
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i02*ew0*ne00;
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ew0 + i01] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+            for (int i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                float * dst_data = wdata;
+                for (int i10 = 0; i10 < ne10; i10++) {
+                    dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total rows in dst
+    const int nr = ne02;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        for (int i0 = 0; i0 < ne10; ++i0) {
+            dst_data[i0] = 0;
+            for (int k = -nh; k <= nh; k++) {
+                float v = 0.0f;
+                ggml_vec_dot_f32(ew0, &v,
+                        (float *) params->wdata +   i1*ew0*ne00 +      (nh + k)*ew0,
+                        (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+                dst_data[i0] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_1d_1s(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_conv_1d_2s
+
+static void ggml_compute_forward_conv_1d_2s_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    //const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    //const int ne0  = dst->ne[0];
+    //const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    //const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    //const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    //const int nb2  = dst->nb[2];
+    //const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00;
+    const int nh = nk/2;
+
+    const int ew0 = ggml_up32(ne01);
+
+    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ew0 + i01] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+            for (int i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                ggml_fp16_t * dst_data = wdata;
+                for (int i10 = 0; i10 < ne10; i10++) {
+                    dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total rows in dst
+    const int nr = ne02;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        for (int i0 = 0; i0 < ne10; i0 += 2) {
+            dst_data[i0/2] = 0;
+            for (int k = -nh; k <= nh; k++) {
+                float v = 0.0f;
+                ggml_vec_dot_f16(ew0, &v,
+                        (ggml_fp16_t *) params->wdata +   i1*ew0*ne00 +      (nh + k)*ew0,
+                        (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+                dst_data[i0/2] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_1d_2s_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    //const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
+
+    //const int ne0  = dst->ne[0];
+    //const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    //const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    //const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
+
+    //const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    //const int nb2  = dst->nb[2];
+    //const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00;
+    const int nh = nk/2;
+
+    const int ew0 = ggml_up32(ne01);
+
+    GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0)
+        {
+            float * const wdata = (float *) params->wdata + 0;
+
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i02*ew0*ne00;
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ew0 + i01] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+            for (int i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                float * dst_data = wdata;
+                for (int i10 = 0; i10 < ne10; i10++) {
+                    dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total rows in dst
+    const int nr = ne02;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        for (int i0 = 0; i0 < ne10; i0 += 2) {
+            dst_data[i0/2] = 0;
+            for (int k = -nh; k <= nh; k++) {
+                float v = 0.0f;
+                ggml_vec_dot_f32(ew0, &v,
+                        (float *) params->wdata +   i1*ew0*ne00 +      (nh + k)*ew0,
+                        (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+                dst_data[i0/2] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_1d_2s(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_flash_attn
+
+static void ggml_compute_forward_flash_attn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+             struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int neq0 = q->ne[0];
+    const int neq1 = q->ne[1];
+    const int neq2 = q->ne[2];
+    const int neq3 = q->ne[3];
+
+    const int nek0 = k->ne[0];
+    const int nek1 = k->ne[1];
+    //const int nek2 = k->ne[2];
+    //const int nek3 = k->ne[3];
+
+    //const int nev0 = v->ne[0];
+    const int nev1 = v->ne[1];
+    //const int nev2 = v->ne[2];
+    //const int nev3 = v->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = neq0;
+    const int N = neq1;
+    const int P = nek1 - N;
+    const int M = P + N;
+
+    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0/sqrt((double) D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
+
+        for (int i = M; i < Mup; ++i) {
+            S[i] = -INFINITY;
+        }
+
+        for (int ic = 0; ic < nek1; ++ic) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f32(neq0,
+                    S + i1,
+                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+
+        // scale
+        ggml_vec_scale_f32(nek1, S, scale);
+
+        if (masked) {
+            for (int i = P; i < M; i++) {
+                if (i > P + iq1) {
+                    S[i] = -INFINITY;
+                }
+            }
+        }
+
+        // softmax
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(M, &max, S);
+
+            float sum = 0.0f;
+            {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                max = -max;
+                vDSP_vsadd(S, 1, &max, S, 1, Mup);
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#else
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    float * SS = S + i;
+
+                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                        if (SS[j] == -INFINITY) {
+                            SS[j] = 0.0f;
+                        } else {
+                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
+                            memcpy(&scvt[j], &s, sizeof(uint16_t));
+                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            sump[j] += val;
+                            SS[j] = val;
+                        }
+                    }
+                }
+
+                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                    sum += sump[i];
+                }
+#endif
+            }
+
+            assert(sum > 0.0f);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(M, S, sum);
+
+#ifndef NDEBUG
+            for (int i = 0; i < M; ++i) {
+                assert(!isnan(S[i]));
+                assert(!isinf(S[i]));
+            }
+#endif
+        }
+
+        for (int ic = 0; ic < nev1; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_dot_f32(nek1,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                    (float *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                    S);
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+             struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int neq0 = q->ne[0];
+    const int neq1 = q->ne[1];
+    const int neq2 = q->ne[2];
+    const int neq3 = q->ne[3];
+
+    const int nek0 = k->ne[0];
+    const int nek1 = k->ne[1];
+    //const int nek2 = k->ne[2];
+    //const int nek3 = k->ne[3];
+
+    //const int nev0 = v->ne[0];
+    const int nev1 = v->ne[1];
+    //const int nev2 = v->ne[2];
+    //const int nev3 = v->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = neq0;
+    const int N = neq1;
+    const int P = nek1 - N;
+    const int M = P + N;
+
+    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0/sqrt((double) D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
+
+        for (int i = M; i < Mup; ++i) {
+            S[i] = -INFINITY;
+        }
+
+        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
+            for (int ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f16(neq0,
+                        S + i1,
+                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+        } else {
+            for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f16_unroll(neq0, nbk1,
+                        S + i1,
+                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+        }
+
+        // scale
+        ggml_vec_scale_f32(nek1, S, scale);
+
+        if (masked) {
+            for (int i = P; i < M; i++) {
+                if (i > P + iq1) {
+                    S[i] = -INFINITY;
+                }
+            }
+        }
+
+        // softmax
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(M, &max, S);
+
+            float sum = 0.0f;
+            {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                max = -max;
+                vDSP_vsadd(S, 1, &max, S, 1, Mup);
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#else
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    float * SS = S + i;
+
+                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                        if (SS[j] == -INFINITY) {
+                            SS[j] = 0.0f;
+                        } else {
+                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
+                            memcpy(&scvt[j], &s, sizeof(uint16_t));
+                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            sump[j] += val;
+                            SS[j] = val;
+                        }
+                    }
+                }
+
+                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                    sum += sump[i];
+                }
+#endif
+            }
+
+            assert(sum > 0.0f);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(M, S, sum);
+
+#ifndef NDEBUG
+            for (int i = 0; i < M; ++i) {
+                assert(!isnan(S[i]));
+                assert(!isinf(S[i]));
+            }
+#endif
+        }
+
+        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
+
+        for (int i = 0; i < M; i++) {
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
+        }
+
+        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
+            for (int ic = 0; ic < nev1; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_dot_f16(nek1,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
+        } else {
+            for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_dot_f16_unroll(nek1, nbv1,
+                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_flash_ff
+
+static void ggml_compute_forward_flash_ff_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,  // F16
+        const struct ggml_tensor * b0, // F16 fc_w
+        const struct ggml_tensor * b1, // F32 fc_b
+        const struct ggml_tensor * c0, // F16 proj_w
+        const struct ggml_tensor * c1, // F32 proj_b
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int nea0 = a->ne[0];
+    const int nea1 = a->ne[1];
+    const int nea2 = a->ne[2];
+    const int nea3 = a->ne[3];
+
+    const int neb00 = b0->ne[0];
+    const int neb01 = b0->ne[1];
+    //const int neb02 = b0->ne[2];
+    //const int neb03 = b0->ne[3];
+
+    const int neb10 = b1->ne[0];
+    const int neb11 = b1->ne[1];
+    //const int neb12 = b1->ne[2];
+    //const int neb13 = b1->ne[3];
+
+    const int nec00 = c0->ne[0];
+    const int nec01 = c0->ne[1];
+    //const int nec02 = c0->ne[2];
+    //const int nec03 = c0->ne[3];
+
+    const int nec10 = c1->ne[0];
+    const int nec11 = c1->ne[1];
+    //const int nec12 = c1->ne[2];
+    //const int nec13 = c1->ne[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    //const int ne3 = dst->ne[3];
+
+    const int nba0 = a->nb[0];
+    const int nba1 = a->nb[1];
+    const int nba2 = a->nb[2];
+    const int nba3 = a->nb[3];
+
+    const int nbb00 = b0->nb[0];
+    const int nbb01 = b0->nb[1];
+    const int nbb02 = b0->nb[2];
+    const int nbb03 = b0->nb[3];
+
+    const int nbb10 = b1->nb[0];
+    //const int nbb11 = b1->nb[1];
+    //const int nbb12 = b1->nb[2];
+    //const int nbb13 = b1->nb[3];
+
+    const int nbc00 = c0->nb[0];
+    const int nbc01 = c0->nb[1];
+    const int nbc02 = c0->nb[2];
+    const int nbc03 = c0->nb[3];
+
+    const int nbc10 = c1->nb[0];
+    //const int nbc11 = c1->nb[1];
+    //const int nbc12 = c1->nb[2];
+    //const int nbc13 = c1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = nea0;
+    //const int N = nea1;
+    const int M = neb01;
+
+    GGML_ASSERT(ne0 == nea0);
+    GGML_ASSERT(ne1 == nea1);
+    GGML_ASSERT(ne2 == nea2);
+
+    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbb10 == sizeof(float));
+    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbc10 == sizeof(float));
+
+    GGML_ASSERT(neb00 == D);
+    GGML_ASSERT(neb01 == M);
+    GGML_ASSERT(neb10 == M);
+    GGML_ASSERT(neb11 == 1);
+
+    GGML_ASSERT(nec00 == M);
+    GGML_ASSERT(nec01 == D);
+    GGML_ASSERT(nec10 == D);
+    GGML_ASSERT(nec11 == 1);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by a rows using ggml_vec_dot_f32
+
+    // total rows in a
+    const int nr = nea1*nea2*nea3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // a indices
+        const int ia3 = ir/(nea2*nea1);
+        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
+        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
+
+        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+
+        for (int ic = 0; ic < neb01; ++ic) {
+            // b0 indices
+            const int ib03 = ia3;
+            const int ib02 = ia2;
+            const int ib01 = ic;
+
+            // S indices
+            const int i1 = ib01;
+
+            ggml_vec_dot_f16(nea0,
+                    S + i1,
+                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
+                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)));
+        }
+
+        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
+        //ggml_vec_gelu_f32(neb01, S, S);
+
+        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+
+        for (int i = 0; i < M; i++) {
+            S16[i] = GGML_FP32_TO_FP16(S[i]);
+        }
+
+        ggml_vec_gelu_f16(neb01, S16, S16);
+
+        {
+            // dst indices
+            const int i1 = ia1;
+            const int i2 = ia2;
+            const int i3 = ia3;
+
+            for (int ic = 0; ic < nec01; ++ic) {
+
+                ggml_vec_dot_f16(neb01,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) c0->data  + (         ic*nbc01 + i2*nbc02 + i3*nbc03)),
+                        S16);
+            }
+
+            ggml_vec_add_f32(nec01,
+                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) c1->data);
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_ff(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b0,
+        const struct ggml_tensor * b1,
+        const struct ggml_tensor * c0,
+        const struct ggml_tensor * c1,
+        struct ggml_tensor * dst) {
+    switch (b0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(false); // TODO
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+/////////////////////////////////
+
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    GGML_ASSERT(params);
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                ggml_compute_forward_dup(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ADD:
+            {
+                ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SUB:
+            {
+                ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_MUL:
+            {
+                ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_DIV:
+            {
+                ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SQR:
+            {
+                ggml_compute_forward_sqr(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SQRT:
+            {
+                ggml_compute_forward_sqrt(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SUM:
+            {
+                ggml_compute_forward_sum(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MEAN:
+            {
+                ggml_compute_forward_mean(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_compute_forward_repeat(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_STEP:
+            {
+                ggml_compute_forward_step(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_NORM:
+            {
+                ggml_compute_forward_norm(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SCALE:
+            {
+                ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_compute_forward_cpy(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                ggml_compute_forward_reshape(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_VIEW:
+            {
+                ggml_compute_forward_view(params, tensor->src0);
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                ggml_compute_forward_permute(params, tensor->src0);
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                ggml_compute_forward_transpose(params, tensor->src0);
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                ggml_compute_forward_soft_max(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ROPE:
+            {
+                ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CONV_1D_1S:
+            {
+                ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CONV_1D_2S:
+            {
+                ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
+            } break;
+        case GGML_OP_FLASH_FF:
+            {
+                ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
+            } break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+    struct ggml_tensor * src0 = tensor->src0;
+    struct ggml_tensor * src1 = tensor->src1;
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_ADD:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_SUB:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_MUL:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx, src1, tensor->grad),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx, src0, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_DIV:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_div(ctx, tensor->grad, src1),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_sub_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_div(ctx, tensor, src1)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQR:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_mul(ctx, src0, tensor->grad),
+                                    ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQRT:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_div(ctx,
+                                    ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+                                    tensor),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SUM:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_repeat(ctx, tensor->grad, src0->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_MEAN:
+            {
+                GGML_ASSERT(false); // TODO: implement
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_sum(ctx, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_ABS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_sgn(ctx, src0),
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SGN:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_NEG:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_STEP:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_RELU:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_sub_impl(ctx,
+                            src0->grad,
+                            ggml_mul(ctx,
+                                ggml_step(ctx, src0),
+                                tensor->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_GELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SILU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                if (src0->grad) {
+                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
+                    GGML_ASSERT(false);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                // TODO: fix transpose, the node will break the graph connections
+                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SCALE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CPY:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_VIEW:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_ROPE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_1D_1S:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_1D_2S:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_FLASH_ATTN:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_FLASH_FF:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+    if (node->grad == NULL) {
+        // this usually happens when we generate intermediate nodes from constants in the backward pass
+        // it can also happen during forward pass, if the user performs computations with constants
+        if (node->op != GGML_OP_NONE) {
+            //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
+        }
+    }
+
+    // check if already visited
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i] == node) {
+            return;
+        }
+    }
+
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        if (cgraph->leafs[i] == node) {
+            return;
+        }
+    }
+
+    if (node->src0) {
+        ggml_visit_parents(cgraph, node->src0);
+    }
+
+    if (node->src1) {
+        ggml_visit_parents(cgraph, node->src1);
+    }
+
+    for (int i = 0; i < GGML_MAX_OPT; ++i) {
+        if (node->opt[i]) {
+            ggml_visit_parents(cgraph, node->opt[i]);
+        }
+    }
+
+    if (node->op == GGML_OP_NONE && node->grad == NULL) {
+        // reached a leaf node, not part of the gradient graph (e.g. a constant)
+        GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
+
+        cgraph->leafs[cgraph->n_leafs] = node;
+        cgraph->n_leafs++;
+    } else {
+        GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
+
+        cgraph->nodes[cgraph->n_nodes] = node;
+        cgraph->grads[cgraph->n_nodes] = node->grad;
+        cgraph->n_nodes++;
+    }
+}
+
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+    if (!expand) {
+        cgraph->n_nodes = 0;
+        cgraph->n_leafs = 0;
+    }
+
+    const int n0 = cgraph->n_nodes;
+    UNUSED(n0);
+
+    ggml_visit_parents(cgraph, tensor);
+
+    const int n_new = cgraph->n_nodes - n0;
+    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+    if (n_new > 0) {
+        // the last added node should always be starting point
+        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+    }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+    ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
+    struct ggml_cgraph result = {
+        /*.n_nodes      =*/ 0,
+        /*.n_leafs      =*/ 0,
+        /*.n_threads    =*/ 0,
+        /*.work_size    =*/ 0,
+        /*.work         =*/ NULL,
+        /*.nodes        =*/ { NULL },
+        /*.grads        =*/ { NULL },
+        /*.leafs        =*/ { NULL },
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+    };
+
+    ggml_build_forward_impl(&result, tensor, false);
+
+    return result;
+}
+
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+    struct ggml_cgraph result = *gf;
+
+    GGML_ASSERT(gf->n_nodes > 0);
+
+    // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
+    if (keep) {
+        for (int i = 0; i < gf->n_nodes; i++) {
+            struct ggml_tensor * node = gf->nodes[i];
+
+            if (node->grad) {
+                node->grad = ggml_dup_tensor(ctx, node);
+                gf->grads[i] = node->grad;
+            }
+        }
+    }
+
+    for (int i = gf->n_nodes - 1; i >= 0; i--) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        // because we detached the grad nodes from the original graph, we can afford inplace operations
+        if (node->grad) {
+            ggml_compute_backward(ctx, node, keep);
+        }
+    }
+
+    for (int i = gf->n_nodes - 1; i >= 0; i--) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->is_param) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            ggml_build_forward_impl(&result, node->grad, true);
+        }
+    }
+
+    return result;
+}
+
+//
+// thread data
+//
+// synchronization is done via busy loops
+// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
+//
+
+#ifdef __APPLE__
+
+//#include <os/lock.h>
+//
+//typedef os_unfair_lock ggml_lock_t;
+//
+//#define ggml_lock_init(x)    UNUSED(x)
+//#define ggml_lock_destroy(x) UNUSED(x)
+//#define ggml_lock_lock       os_unfair_lock_lock
+//#define ggml_lock_unlock     os_unfair_lock_unlock
+//
+//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x)    UNUSED(x)
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#else
+
+//typedef pthread_spinlock_t ggml_lock_t;
+
+//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
+//#define ggml_lock_destroy pthread_spin_destroy
+//#define ggml_lock_lock    pthread_spin_lock
+//#define ggml_lock_unlock  pthread_spin_unlock
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x)    UNUSED(x)
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#endif
+
+struct ggml_compute_state_shared {
+    ggml_lock_t spin;
+
+    int n_threads;
+
+    // synchronization primitives
+    atomic_int  n_ready;
+    atomic_bool has_work;
+    atomic_bool stop; // stop all threads
+};
+
+struct ggml_compute_state {
+    ggml_thread_t thrd;
+
+    struct ggml_compute_params params;
+    struct ggml_tensor * node;
+
+    struct ggml_compute_state_shared * shared;
+};
+
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+    const int n_threads = state->shared->n_threads;
+
+    while (true) {
+        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
+            atomic_store(&state->shared->has_work, false);
+        } else {
+            while (atomic_load(&state->shared->has_work)) {
+                if (atomic_load(&state->shared->stop)) {
+                    return 0;
+                }
+                ggml_lock_lock  (&state->shared->spin);
+                ggml_lock_unlock(&state->shared->spin);
+            }
+        }
+
+        atomic_fetch_sub(&state->shared->n_ready, 1);
+
+        // wait for work
+        while (!atomic_load(&state->shared->has_work)) {
+            if (atomic_load(&state->shared->stop)) {
+                return 0;
+            }
+            ggml_lock_lock  (&state->shared->spin);
+            ggml_lock_unlock(&state->shared->spin);
+        }
+
+        // check if we should stop
+        if (atomic_load(&state->shared->stop)) {
+            break;
+        }
+
+        if (state->node) {
+            if (state->params.ith < state->params.nth) {
+                ggml_compute_forward(&state->params, state->node);
+            }
+
+            state->node = NULL;
+        } else {
+            break;
+        }
+    }
+
+    return 0;
+}
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    if (cgraph->n_threads <= 0) {
+        cgraph->n_threads = 8;
+    }
+
+    const int n_threads = cgraph->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.spin      =*/ GGML_LOCK_INITIALIZER,
+        /*.n_threads =*/ n_threads,
+        /*.n_ready   =*/ 0,
+        /*.has_work  =*/ false,
+        /*.stop      =*/ false,
+    };
+    struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
+
+    // create thread pool
+    if (n_threads > 1) {
+        ggml_lock_init(&state_shared.spin);
+
+        atomic_store(&state_shared.has_work, true);
+
+        for (int j = 0; j < n_threads - 1; j++) {
+            workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .params = {
+                    .type  = GGML_TASK_COMPUTE,
+                    .ith   = j + 1,
+                    .nth   = n_threads,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                },
+                .node   = NULL,
+                .shared = &state_shared,
+            };
+
+            int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+    }
+
+    // initialize tasks + work buffer
+    {
+        size_t work_size = 0;
+
+        // thread scheduling for the different operations
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            struct ggml_tensor * node = cgraph->nodes[i];
+
+            switch (node->op) {
+                case GGML_OP_DUP:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_ADD:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_SUB:
+                case GGML_OP_MUL:
+                case GGML_OP_DIV:
+                case GGML_OP_SQR:
+                case GGML_OP_SQRT:
+                case GGML_OP_SUM:
+                case GGML_OP_MEAN:
+                case GGML_OP_REPEAT:
+                case GGML_OP_ABS:
+                case GGML_OP_SGN:
+                case GGML_OP_NEG:
+                case GGML_OP_STEP:
+                case GGML_OP_RELU:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_GELU:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_SILU:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_NORM:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_MUL_MAT:
+                    {
+                        node->n_tasks = n_threads;
+
+                        // TODO: use different scheduling for different matrix sizes
+                        //const int nr0 = ggml_nrows(node->src0);
+                        //const int nr1 = ggml_nrows(node->src1);
+
+                        //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                        //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
+
+                        size_t cur = 0;
+
+                        // TODO: better way to determine if the matrix is transposed
+                        if (node->src0->nb[1] < node->src0->nb[0]) {
+                            cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
+                                                                   // TODO: overestimated by factor of x2 for FP16
+                        } else {
+                            if (node->src0->type == GGML_TYPE_F16 &&
+                                node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                    node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                       //       the threads are still spinning
+                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                    //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                                    //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                                    //printf("cur = %zu\n", cur);
+                                } else {
+                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+                                }
+#else
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#endif
+                            } else if (node->src0->type == GGML_TYPE_F32 &&
+                                       node->src1->type == GGML_TYPE_F32) {
+                                cur = 0;
+                            } else if (node->src0->type == GGML_TYPE_Q4_0 &&
+                                       node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                    node->n_tasks = 1;
+                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                } else {
+                                    cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+                                }
+#else
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+#endif
+                            } else if (node->src0->type == GGML_TYPE_Q4_1 &&
+                                       node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                    node->n_tasks = 1;
+                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                } else {
+                                    cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+                                }
+#else
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+#endif
+                            } else {
+                                GGML_ASSERT(false);
+                            }
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_SCALE:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_CPY:
+                case GGML_OP_RESHAPE:
+                case GGML_OP_VIEW:
+                case GGML_OP_PERMUTE:
+                case GGML_OP_TRANSPOSE:
+                case GGML_OP_GET_ROWS:
+                case GGML_OP_DIAG_MASK_INF:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_SOFT_MAX:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
+                case GGML_OP_ROPE:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_CONV_1D_1S:
+                case GGML_OP_CONV_1D_2S:
+                    {
+                        node->n_tasks = n_threads;
+
+                        GGML_ASSERT(node->src0->ne[3] == 1);
+                        GGML_ASSERT(node->src1->ne[2] == 1);
+                        GGML_ASSERT(node->src1->ne[3] == 1);
+
+                        size_t cur = 0;
+                        const int nk = node->src0->ne[0];
+
+                        if (node->src0->type == GGML_TYPE_F16 &&
+                            node->src1->type == GGML_TYPE_F32) {
+                            cur = sizeof(ggml_fp16_t)*(
+                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                    );
+                        } else if (node->src0->type == GGML_TYPE_F32 &&
+                                   node->src1->type == GGML_TYPE_F32) {
+                            cur = sizeof(float)*(
+                                    nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+                                    ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+                                    );
+                        } else {
+                            GGML_ASSERT(false);
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_ATTN:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_FF:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_NONE:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_COUNT:
+                    {
+                        GGML_ASSERT(false);
+                    } break;
+            }
+        }
+
+        if (cgraph->work != NULL && work_size > cgraph->work_size) {
+            GGML_ASSERT(false); // TODO: better handling
+        }
+
+        if (work_size > 0 && cgraph->work == NULL) {
+            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+
+            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
+            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+        }
+    }
+
+    const int64_t perf_start_cycles  = ggml_perf_cycles();
+    const int64_t perf_start_time_us = ggml_perf_time_us();
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
+
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
+        //if (node->grad == NULL && node->perf_runs > 0) {
+        //    continue;
+        //}
+
+        const int64_t perf_node_start_cycles  = ggml_perf_cycles();
+        const int64_t perf_node_start_time_us = ggml_perf_time_us();
+
+        // INIT
+        struct ggml_compute_params params = {
+            /*.type  =*/ GGML_TASK_INIT,
+            /*.ith   =*/ 0,
+            /*.nth   =*/ node->n_tasks,
+            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+        };
+
+        ggml_compute_forward(&params, node);
+
+        // COMPUTE
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            // launch thread pool
+            for (int j = 0; j < n_threads - 1; j++) {
+                workers[j].params = (struct ggml_compute_params) {
+                    .type  = GGML_TASK_COMPUTE,
+                    .ith   = j + 1,
+                    .nth   = node->n_tasks,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                };
+                workers[j].node = node;
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) > 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_store(&state_shared.has_work, true);
+        }
+
+        params.type = GGML_TASK_COMPUTE;
+        ggml_compute_forward(&params, node);
+
+        // wait for thread pool
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) != 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+        }
+
+        // FINALIZE
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            // launch thread pool
+            for (int j = 0; j < n_threads - 1; j++) {
+                workers[j].params = (struct ggml_compute_params) {
+                    .type  = GGML_TASK_FINALIZE,
+                    .ith   = j + 1,
+                    .nth   = node->n_tasks,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                };
+                workers[j].node = node;
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) > 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_store(&state_shared.has_work, true);
+        }
+
+        params.type = GGML_TASK_FINALIZE;
+        ggml_compute_forward(&params, node);
+
+        // wait for thread pool
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) != 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+        }
+
+        // performance stats (node)
+        {
+            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
+            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+
+            node->perf_runs++;
+            node->perf_cycles  += perf_cycles_cur;
+            node->perf_time_us += perf_time_us_cur;
+        }
+    }
+
+    // join thread pool
+    if (n_threads > 1) {
+        atomic_store(&state_shared.stop, true);
+        atomic_store(&state_shared.has_work, true);
+
+        for (int j = 0; j < n_threads - 1; j++) {
+            int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+
+        ggml_lock_destroy(&state_shared.spin);
+    }
+
+    // performance stats (graph)
+    {
+        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
+        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+
+        cgraph->perf_runs++;
+        cgraph->perf_cycles  += perf_cycles_cur;
+        cgraph->perf_time_us += perf_time_us_cur;
+
+        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
+                __func__, cgraph->perf_runs,
+                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
+                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
+                (double) perf_time_us_cur     / 1000.0,
+                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+    }
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * grad = cgraph->grads[i];
+
+        if (grad) {
+            ggml_set_zero(grad);
+        }
+    }
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+    int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
+
+    GGML_PRINT("=== GRAPH ===\n");
+
+    GGML_PRINT_DEBUG("n_threads       = %d\n",       cgraph->n_threads);
+    GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+
+    GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        perf_total_per_op_us[node->op] += node->perf_time_us;
+
+        GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+                i,
+                node->ne[0], node->ne[1], node->ne[2],
+                GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
+                (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
+                (double) node->perf_time_us / 1000.0,
+                (double) node->perf_time_us / 1000.0 / node->perf_runs);
+    }
+
+    GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        struct ggml_tensor * node = cgraph->leafs[i];
+
+        GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+                i,
+                node->ne[0], node->ne[1],
+                GGML_OP_LABEL[node->op]);
+    }
+
+    for (int i = 0; i < GGML_OP_COUNT; i++) {
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+    }
+
+    GGML_PRINT("========================================\n");
+}
+
+// check if node is part of the graph
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    if (cgraph == NULL) {
+        return true;
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i] == node) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * parent = cgraph->nodes[i];
+
+        if (parent->grad == node) {
+            return parent;
+        }
+    }
+
+    return NULL;
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+    char color[16];
+
+    FILE * fp = fopen(filename, "w");
+    GGML_ASSERT(fp);
+
+    fprintf(fp, "digraph G {\n");
+    fprintf(fp, "  newrank = true;\n");
+    fprintf(fp, "  rankdir = LR;\n");
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+
+        if (ggml_graph_get_parent(gb, node) != NULL) {
+            continue;
+        }
+
+        if (node->is_param) {
+            snprintf(color, sizeof(color), "yellow");
+        } else if (node->grad) {
+            if (ggml_graph_find(gf, node)) {
+                snprintf(color, sizeof(color), "green");
+            } else {
+                snprintf(color, sizeof(color), "lightblue");
+            }
+        } else {
+            snprintf(color, sizeof(color), "white");
+        }
+
+        fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"%d [%d, %d] | <x>%s",
+                (void *) node, color,
+                i, node->ne[0], node->ne[1],
+                GGML_OP_SYMBOL[node->op]);
+
+        if (node->grad) {
+            fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+        } else {
+            fprintf(fp, "\"; ]\n");
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        snprintf(color, sizeof(color), "pink");
+
+        if (ggml_nelements(node) == 1) {
+            fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>%.1e\"; ]\n",
+                    (void *) node, color, ggml_get_f32_1d(node, 0));
+        } else {
+            fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>CONST %d [%d, %d]\"; ]\n",
+                    (void *) node, color,
+                    i, node->ne[0], node->ne[1]);
+        }
+    }
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+
+        struct ggml_tensor * parent = ggml_graph_get_parent(gb, node);
+
+        if (node->src0) {
+            struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0);
+
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n",
+                    parent0 ? (void *) parent0 : (void *) node->src0,
+                    parent0 ? "g" : "x",
+                    parent ? (void *) parent : (void *) node,
+                    parent ? "g" : "x",
+                    parent ? "empty" : "vee",
+                    parent ? "dashed" : "solid");
+        }
+
+        if (node->src1) {
+            struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1);
+
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n",
+                    parent1 ? (void *) parent1 : (void *) node->src1,
+                    parent1 ? "g" : "x",
+                    parent ? (void *) parent : (void *) node,
+                    parent ? "g" : "x",
+                    parent ? "empty" : "vee",
+                    parent ? "dashed" : "solid");
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        if (node->src0) {
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n",
+                    (void *) node->src0, "x",
+                    (void *) node, "x");
+        }
+
+        if (node->src1) {
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n",
+                    (void *) node->src1, "x",
+                    (void *) node, "x");
+        }
+    }
+
+    fprintf(fp, "}\n");
+
+    fclose(fp);
+
+    GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to set tensor from array
+        for (int j = 0; j < ne; ++j) {
+            ggml_set_f32_1d(ps[p], j, x[i++]);
+        }
+    }
+}
+
+static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int j = 0; j < ne; ++j) {
+            x[i++] = ggml_get_f32_1d(ps[p], j);
+        }
+    }
+}
+
+static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int j = 0; j < ne; ++j) {
+            g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
+        }
+    }
+}
+
+//
+// ADAM
+//
+//   ref: https://arxiv.org/pdf/1412.6980.pdf
+//
+
+static enum ggml_opt_result ggml_opt_adam(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+    GGML_ASSERT(ggml_is_scalar(f));
+
+    gf->n_threads = params.n_threads;
+    gb->n_threads = params.n_threads;
+
+    // these will store the parameters we want to optimize
+    struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+    int np = 0;
+    int nx = 0;
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        if (gf->nodes[i]->is_param) {
+            GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
+
+            ps[np++] = gf->nodes[i];
+            nx += ggml_nelements(gf->nodes[i]);
+        }
+    }
+
+    // constants
+    const float alpha = params.adam.alpha;
+    const float beta1 = params.adam.beta1;
+    const float beta2 = params.adam.beta2;
+    const float eps   = params.adam.eps;
+
+    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
+    float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
+    float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
+    float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
+    float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
+    float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
+    float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+
+    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+    // initialize
+    ggml_vec_set_f32(nx, m, 0.0f);
+    ggml_vec_set_f32(nx, v, 0.0f);
+
+    // update view
+    ggml_opt_get_params(np, ps, x);
+
+    // compute the function value
+    ggml_graph_reset  (gf);
+    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_compute(ctx, gb);
+
+    float fx_prev = ggml_get_f32_1d(f, 0);
+    if (pf) {
+        pf[0] = fx_prev;
+    }
+
+    int n_no_improvement = 0;
+    float fx_best = fx_prev;
+
+    // run the optimizer
+    for (int t = 0; t < params.adam.n_iter; ++t) {
+        GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
+
+        GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
+        GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0));
+        GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0));
+
+        for (int i = 0; i < np; ++i) {
+            GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
+                    ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0));
+        }
+
+        const int64_t t_start_wall = ggml_time_us();
+        const int64_t t_start_cpu = ggml_cycles();
+        UNUSED(t_start_wall);
+        UNUSED(t_start_cpu);
+
+        {
+            // update the gradient
+            ggml_opt_get_grad(np, ps, g1);
+
+            // m_t = beta1*m_t-1 + (1 - beta1)*g_t
+            ggml_vec_scale_f32(nx, m, beta1);
+            ggml_vec_mad_f32  (nx, m, g1, 1.0f - beta1);
+
+            // g2 = g1^2
+            ggml_vec_sqr_f32  (nx, g2, g1);
+
+            // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
+            ggml_vec_scale_f32(nx, v, beta2);
+            ggml_vec_mad_f32  (nx, v, g2, 1.0f - beta2);
+
+            // m^hat = m_t / (1 - beta1^t)
+            // v^hat = v_t / (1 - beta2^t)
+            // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+            ggml_vec_cpy_f32  (nx, mh, m);
+            ggml_vec_cpy_f32  (nx, vh, v);
+
+            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
+            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, t + 1)));
+
+            ggml_vec_sqrt_f32 (nx, vh, vh);
+            ggml_vec_acc1_f32 (nx, vh, eps);
+
+            ggml_vec_div_f32  (nx, mh, mh, vh);
+            ggml_vec_sub_f32  (nx, x,  x,  mh);
+
+            // update the parameters
+            ggml_opt_set_params(np, ps, x);
+        }
+
+        ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(ctx, gb);
+
+        const float fx = ggml_get_f32_1d(f, 0);
+
+        // check convergence
+        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+            GGML_PRINT_DEBUG("converged\n");
+
+            return GGML_OPT_OK;
+        }
+
+        // delta-based convergence test
+        if (pf != NULL) {
+            // need at least params.past iterations to start checking for convergence
+            if (params.past <= t) {
+                const float rate = (pf[t%params.past] - fx)/fx;
+
+                if (fabs(rate) < params.delta) {
+                    return GGML_OPT_OK;
+                }
+            }
+
+            pf[t%params.past] = fx;
+        }
+
+        // check for improvement
+        if (params.max_no_improvement > 0) {
+            if (fx_best > fx) {
+                fx_best = fx;
+                n_no_improvement = 0;
+            } else {
+                ++n_no_improvement;
+
+                if (n_no_improvement >= params.max_no_improvement) {
+                    return GGML_OPT_OK;
+                }
+            }
+        }
+
+        fx_prev = fx;
+
+        {
+            const int64_t t_end_cpu = ggml_cycles();
+            GGML_PRINT_DEBUG("time iter:      %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC);
+            UNUSED(t_end_cpu);
+
+            const int64_t t_end_wall = ggml_time_us();
+            GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
+            UNUSED(t_end_wall);
+        }
+    }
+
+    return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+//
+// L-BFGS
+//
+// the L-BFGS implementation below is based on the following implementation:
+//
+//   https://github.com/chokkan/liblbfgs
+//
+
+struct ggml_lbfgs_iteration_data {
+    float alpha;
+    float ys;
+    float * s;
+    float * y;
+};
+
+static enum ggml_opt_result linesearch_backtracking(
+        struct ggml_context * ctx,
+        const struct ggml_opt_params * params,
+        int nx,
+        float * x,
+        float * fx,
+        float * g,
+        float * d,
+        float * step,
+        const float * xp,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb,
+        const int np,
+        struct ggml_tensor * ps[]) {
+    int count = 0;
+
+    float width  = 0.0f;
+    float dg     = 0.0f;
+    float finit  = 0.0f;
+    float dginit = 0.0f;
+    float dgtest = 0.0f;
+
+    const float dec = 0.5f;
+    const float inc = 2.1f;
+
+    if (*step <= 0.) {
+        return GGML_LINESEARCH_INVALID_PARAMETERS;
+    }
+
+    // compute the initial gradient in the search direction
+    ggml_vec_dot_f32(nx, &dginit, g, d);
+
+    // make sure that d points to a descent direction
+    if (0 < dginit) {
+        return GGML_LINESEARCH_FAIL;
+    }
+
+    // initialize local variables
+    finit = *fx;
+    dgtest = params->lbfgs.ftol*dginit;
+
+    while (true) {
+        ggml_vec_cpy_f32(nx, x, xp);
+        ggml_vec_mad_f32(nx, x, d, *step);
+
+        // evaluate the function and gradient values
+        {
+            ggml_opt_set_params(np, ps, x);
+
+            ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(ctx, gb);
+
+            ggml_opt_get_grad(np, ps, g);
+
+            *fx = ggml_get_f32_1d(f, 0);
+        }
+
+        ++count;
+
+        if (*fx > finit + (*step)*dgtest) {
+            width = dec;
+        } else {
+            // Armijo condition is satisfied
+            if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) {
+                return count;
+            }
+
+            ggml_vec_dot_f32(nx, &dg, g, d);
+
+            // check the Wolfe condition
+            if (dg < params->lbfgs.wolfe * dginit) {
+                width = inc;
+            } else {
+                if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) {
+                    // regular Wolfe conditions
+                    return count;
+                }
+
+                if(dg > -params->lbfgs.wolfe*dginit) {
+                    width = dec;
+                } else {
+                    // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
+                    return count;
+                }
+                return count;
+            }
+        }
+
+        if (*step < params->lbfgs.min_step) {
+            return GGML_LINESEARCH_MINIMUM_STEP;
+        }
+        if (*step > params->lbfgs.max_step) {
+            return GGML_LINESEARCH_MAXIMUM_STEP;
+        }
+        if (params->lbfgs.max_linesearch <= count) {
+            return GGML_LINESEARCH_MAXIMUM_ITERATIONS;
+        }
+
+        (*step) *= width;
+    }
+
+    return GGML_LINESEARCH_FAIL;
+}
+
+static enum ggml_opt_result ggml_opt_lbfgs(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+    if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
+        params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+            return GGML_OPT_INVALID_WOLFE;
+        }
+    }
+
+    gf->n_threads = params.n_threads;
+    gb->n_threads = params.n_threads;
+
+    const int m = params.lbfgs.m;
+
+    // these will store the parameters we want to optimize
+    struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+    int np = 0;
+    int nx = 0;
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        if (gf->nodes[i]->is_param) {
+            GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
+
+            ps[np++] = gf->nodes[i];
+            nx += ggml_nelements(gf->nodes[i]);
+        }
+    }
+
+    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
+    float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
+    float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
+    float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
+    float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+
+    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+    float fx    = 0.0f; // cost function value
+    float xnorm = 0.0f; // ||x||
+    float gnorm = 0.0f; // ||g||
+    float step  = 0.0f;
+
+    // initialize x from the graph nodes
+    ggml_opt_get_params(np, ps, x);
+
+    // the L-BFGS memory
+    struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
+
+    for (int i = 0; i < m; ++i) {
+        lm[i].alpha = 0.0f;
+        lm[i].ys    = 0.0f;
+        lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+        lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+    }
+
+    // evaluate the function value and its gradient
+    {
+        ggml_opt_set_params(np, ps, x);
+
+        ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(ctx, gb);
+
+        ggml_opt_get_grad(np, ps, g);
+
+        fx = ggml_get_f32_1d(f, 0);
+    }
+
+    if (pf) {
+        pf[0] = fx;
+    }
+
+    float fx_best = fx;
+
+    // search direction = -gradient
+    ggml_vec_neg_f32(nx, d, g);
+
+    // ||x||, ||g||
+    ggml_vec_norm_f32(nx, &xnorm, x);
+    ggml_vec_norm_f32(nx, &gnorm, g);
+
+    if (xnorm < 1.0f) {
+        xnorm = 1.0f;
+    }
+
+    // already optimized
+    if (gnorm/xnorm <= params.lbfgs.eps) {
+        return GGML_OPT_OK;
+    }
+
+    // initial step
+    ggml_vec_norm_inv_f32(nx, &step, d);
+
+    int j                = 0;
+    int k                = 1;
+    int ls               = 0;
+    int end              = 0;
+    int bound            = 0;
+    int n_no_improvement = 0;
+
+    float ys   = 0.0f;
+    float yy   = 0.0f;
+    float beta = 0.0f;
+
+    while (true) {
+        // store the current position and gradient vectors
+        ggml_vec_cpy_f32(nx, xp, x);
+        ggml_vec_cpy_f32(nx, gp, g);
+
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+
+        if (ls < 0) {
+            // linesearch failed - go back to the previous point and return
+            ggml_vec_cpy_f32(nx, x, xp);
+            ggml_vec_cpy_f32(nx, g, gp);
+
+            return ls;
+        }
+
+        ggml_vec_norm_f32(nx, &xnorm, x);
+        ggml_vec_norm_f32(nx, &gnorm, g);
+
+        GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+
+        if (xnorm < 1.0) {
+            xnorm = 1.0;
+        }
+        if (gnorm/xnorm <= params.lbfgs.eps) {
+            // converged
+            return GGML_OPT_OK;
+        }
+
+        // delta-based convergence test
+        if (pf != NULL) {
+            // need at least params.past iterations to start checking for convergence
+            if (params.past <= k) {
+                const float rate = (pf[k%params.past] - fx)/fx;
+
+                if (fabs(rate) < params.delta) {
+                    return GGML_OPT_OK;
+                }
+            }
+
+            pf[k%params.past] = fx;
+        }
+
+        // check for improvement
+        if (params.max_no_improvement > 0) {
+            if (fx < fx_best) {
+                fx_best = fx;
+                n_no_improvement = 0;
+            } else {
+                n_no_improvement++;
+
+                if (n_no_improvement >= params.max_no_improvement) {
+                    return GGML_OPT_OK;
+                }
+            }
+        }
+
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+            // reached the maximum number of iterations
+            return GGML_OPT_DID_NOT_CONVERGE;
+        }
+
+        // update vectors s and y:
+        //   s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
+        //   y_{k+1} = g_{k+1} - g_{k}.
+        //
+        ggml_vec_sub_f32(nx, lm[end].s, x, xp);
+        ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+
+        // compute scalars ys and yy:
+        //     ys = y^t \cdot s    -> 1 / \rho.
+        //     yy = y^t \cdot y.
+        //
+        ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
+        ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+
+        lm[end].ys = ys;
+
+        // find new search direction
+        //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
+
+        bound = (m <= k) ? m : k;
+        k++;
+        end = (end + 1)%m;
+
+        // initialize search direction with -g
+        ggml_vec_neg_f32(nx, d, g);
+
+        j = end;
+        for (int i = 0; i < bound; ++i) {
+            j = (j + m - 1) % m;
+            // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+            ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
+            lm[j].alpha /= lm[j].ys;
+            // q_{i} = q_{i+1} - \alpha_{i} y_{i}
+            ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+        }
+
+        ggml_vec_scale_f32(nx, d, ys/yy);
+
+        for (int i = 0; i < bound; ++i) {
+            // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+            ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
+            beta /= lm[j].ys;
+            // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
+            ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
+            j = (j + 1)%m;
+        }
+
+        step = 1.0;
+    }
+
+    return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
+    struct ggml_opt_params result;
+
+    switch (type) {
+        case GGML_OPT_ADAM:
+            {
+                result = (struct ggml_opt_params) {
+                    .type      = GGML_OPT_ADAM,
+                    .n_threads = 1,
+                    .past      = 0,
+                    .delta     = 1e-5f,
+
+                    .max_no_improvement = 100,
+
+                    .print_forward_graph  = true,
+                    .print_backward_graph = true,
+
+                    .adam = {
+                        .n_iter = 10000,
+                        .alpha  = 0.001f,
+                        .beta1  = 0.9f,
+                        .beta2  = 0.999f,
+                        .eps    = 1e-8f,
+                        .eps_f  = 1e-5f,
+                        .eps_g  = 1e-3f,
+                    },
+                };
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                result = (struct ggml_opt_params) {
+                    .type      = GGML_OPT_LBFGS,
+                    .n_threads = 1,
+                    .past      = 0,
+                    .delta     = 1e-5f,
+
+                    .max_no_improvement = 0,
+
+                    .print_forward_graph  = true,
+                    .print_backward_graph = true,
+
+                    .lbfgs = {
+                        .m              = 6,
+                        .n_iter         = 100,
+                        .max_linesearch = 20,
+
+                        .eps      = 1e-5f,
+                        .ftol     = 1e-4f,
+                        .wolfe    = 0.9f,
+                        .min_step = 1e-20f,
+                        .max_step = 1e+20f,
+
+                        .linesearch = GGML_LINESEARCH_DEFAULT,
+                    },
+                };
+            } break;
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f) {
+    bool free_ctx = false;
+    if (ctx == NULL) {
+        struct ggml_init_params params_ctx = {
+            .mem_size   = 16*1024*1024,
+            .mem_buffer = NULL,
+        };
+
+        ctx = ggml_init(params_ctx);
+        if (ctx == NULL) {
+            return GGML_OPT_NO_CONTEXT;
+        }
+
+        free_ctx = true;
+    }
+
+    enum ggml_opt_result result = GGML_OPT_OK;
+
+    // build forward + backward compute graphs
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
+
+    switch (params.type) {
+        case GGML_OPT_ADAM:
+            {
+                result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+            } break;
+    }
+
+    if (params.print_forward_graph) {
+        ggml_graph_print   (&gf);
+        ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
+    }
+
+    if (params.print_backward_graph) {
+        ggml_graph_print   (&gb);
+        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    }
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_NEON)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_blas(void) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_sse3(void) {
+#if defined(__SSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_vsx(void) {
+#if defined(__POWER9_VECTOR__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/ggml.h b/ggml.h
new file mode 100644 (file)
index 0000000..7ce655c
--- /dev/null
+++ b/ggml.h
@@ -0,0 +1,758 @@
+#pragma once
+
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you wish some specific topics to be covered, feel free to drop a comment:
+//
+//   https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+//  - a set of tensor operations
+//  - automatic differentiation
+//  - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+//  - linear regression
+//  - support vector machines
+//  - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+//   {
+//       struct ggml_init_params params = {
+//           .mem_size   = 16*1024*1024,
+//           .mem_buffer = NULL,
+//       };
+//
+//       // memory allocation happens here
+//       struct ggml_context * ctx = ggml_init(params);
+//
+//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+//       ggml_set_param(ctx, x); // x is an input variable
+//
+//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+//       ...
+//   }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+//   {
+//       ...
+//
+//       struct ggml_cgraph gf = ggml_build_forward(f);
+//
+//       // set the input variable and parameter values
+//       ggml_set_f32(x, 2.0f);
+//       ggml_set_f32(a, 3.0f);
+//       ggml_set_f32(b, 4.0f);
+//
+//       ggml_graph_compute(ctx0, &gf);
+//
+//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+//       ...
+//   }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach allows to define the function graph once and then compute its forward or backward graphs
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+//   - ggml_permute()
+//   - ggml_conv_1d_1s()
+//   - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+//   What is Automatic Differentiation?
+//   https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+//   {
+//       struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+//       assert(c->src[0] == a);
+//       assert(c->src[1] == b);
+//   }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
+// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+//   {
+//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
+//
+//       // a[1, 2] = 1.0f;
+//       *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
+//
+//       // a[2, 0] = 2.0f;
+//       *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
+//
+//       ...
+//   }
+//
+// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
+//
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdbool.h>
+
+#define GGML_MAX_DIMS     4
+#define GGML_MAX_NODES    4096
+#define GGML_MAX_PARAMS   16
+#define GGML_MAX_CONTEXTS 64
+#define GGML_MAX_OPT      4
+
+#ifdef __ARM_NEON
+// we use the built-in 16-bit float type
+typedef __fp16 ggml_fp16_t;
+#else
+typedef uint16_t ggml_fp16_t;
+#endif
+
+// convert FP16 <-> FP32
+float       ggml_fp16_to_fp32(ggml_fp16_t x);
+ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+struct ggml_object;
+struct ggml_context;
+
+enum ggml_type {
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_I8,
+    GGML_TYPE_I16,
+    GGML_TYPE_I32,
+    GGML_TYPE_F16,
+    GGML_TYPE_F32,
+    GGML_TYPE_COUNT,
+};
+
+// available tensor operations:
+enum ggml_op {
+    GGML_OP_NONE = 0,
+
+    GGML_OP_DUP,
+    GGML_OP_ADD,
+    GGML_OP_SUB,
+    GGML_OP_MUL,
+    GGML_OP_DIV,
+    GGML_OP_SQR,
+    GGML_OP_SQRT,
+    GGML_OP_SUM,
+    GGML_OP_MEAN,
+    GGML_OP_REPEAT,
+    GGML_OP_ABS,
+    GGML_OP_SGN,
+    GGML_OP_NEG,
+    GGML_OP_STEP,
+    GGML_OP_RELU,
+    GGML_OP_GELU,
+    GGML_OP_SILU,
+    GGML_OP_NORM, // normalize
+
+    GGML_OP_MUL_MAT,
+
+    GGML_OP_SCALE,
+    GGML_OP_CPY,
+    GGML_OP_RESHAPE,
+    GGML_OP_VIEW,
+    GGML_OP_PERMUTE,
+    GGML_OP_TRANSPOSE,
+    GGML_OP_GET_ROWS,
+    GGML_OP_DIAG_MASK_INF,
+    GGML_OP_SOFT_MAX,
+    GGML_OP_ROPE,
+    GGML_OP_CONV_1D_1S,
+    GGML_OP_CONV_1D_2S,
+
+    GGML_OP_FLASH_ATTN,
+    GGML_OP_FLASH_FF,
+
+    GGML_OP_COUNT,
+};
+
+// n-dimensional tensor
+struct ggml_tensor {
+    enum ggml_type type;
+
+    int    n_dims;
+    int    ne[GGML_MAX_DIMS]; // number of elements
+    size_t nb[GGML_MAX_DIMS]; // stride in bytes:
+                              // nb[0] = sizeof(type)
+                              // nb[1] = nb[0]   * ne[0] + padding
+                              // nb[i] = nb[i-1] * ne[i-1]
+
+    // compute data
+    enum ggml_op op;
+
+    bool is_param;
+
+    struct ggml_tensor * grad;
+    struct ggml_tensor * src0;
+    struct ggml_tensor * src1;
+    struct ggml_tensor * opt[GGML_MAX_OPT];
+
+    // thread scheduling
+    int n_tasks;
+
+    // performance
+    int     perf_runs;
+    int64_t perf_cycles;
+    int64_t perf_time_us;
+
+    void * data;
+    char padding[8];
+};
+
+// computation graph
+struct ggml_cgraph {
+    int n_nodes;
+    int n_leafs;
+    int n_threads;
+
+    size_t work_size;
+    struct ggml_tensor * work;
+
+    struct ggml_tensor * nodes[GGML_MAX_NODES];
+    struct ggml_tensor * grads[GGML_MAX_NODES];
+    struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+    // performance
+    int     perf_runs;
+    int64_t perf_cycles;
+    int64_t perf_time_us;
+};
+
+// scratch buffer
+struct ggml_scratch {
+    size_t offs;
+    size_t size;
+    void * data;
+};
+
+struct ggml_init_params {
+    // memory pool
+    size_t mem_size;   // bytes
+    void * mem_buffer; // if NULL, memory will be allocated internally
+};
+
+void    ggml_time_init(void); // call this once at the beginning of the program
+int64_t ggml_time_ms(void);
+int64_t ggml_time_us(void);
+int64_t ggml_cycles(void);
+int64_t ggml_cycles_per_ms(void);
+
+void ggml_print_object (const struct ggml_object * obj);
+void ggml_print_objects(const struct ggml_context * ctx);
+
+int    ggml_nelements(const struct ggml_tensor * tensor);
+size_t ggml_nbytes   (const struct ggml_tensor * tensor);
+
+int    ggml_blck_size (enum ggml_type type);
+size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+
+size_t ggml_element_size(const struct ggml_tensor * tensor);
+
+struct ggml_context * ggml_init(struct ggml_init_params params);
+void ggml_free(struct ggml_context * ctx);
+
+size_t ggml_used_mem(const struct ggml_context * ctx);
+
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+
+struct ggml_tensor * ggml_new_tensor(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int *ne);
+
+struct ggml_tensor * ggml_new_tensor_1d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0);
+
+struct ggml_tensor * ggml_new_tensor_2d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1);
+
+struct ggml_tensor * ggml_new_tensor_3d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2);
+
+struct ggml_tensor * ggml_new_tensor_4d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2,
+        int    ne3);
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+ void * ggml_get_data    (const struct ggml_tensor * tensor);
+float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+//
+// operations on tensors with backpropagation
+//
+
+struct ggml_tensor * ggml_dup(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_sub(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_mul(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_div(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_sqrt(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// return scalar
+// TODO: compute sum along rows
+struct ggml_tensor * ggml_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// mean along rows
+struct ggml_tensor * ggml_mean(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// if a is the same shape as b, and a is not parameter, return a
+// otherwise, return a new tensor: repeat(a) to fit in b
+struct ggml_tensor * ggml_repeat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_abs(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_sgn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_neg(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_step(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// TODO: double-check this computation is correct
+struct ggml_tensor * ggml_gelu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// normalize along rows
+// TODO: eps is hardcoded to 1e-5 for now
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// A: m rows, n columns
+// B: p rows, n columns (i.e. we transpose it internally)
+// result is m columns, p rows
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+//
+// operations on tensors without backpropagation
+//
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_scale(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// a -> b, return view(b)
+struct ggml_tensor * ggml_cpy(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// return view(a), b specifies the new shape
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2);
+
+// offset in bytes
+struct ggml_tensor * ggml_view_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        size_t                offset);
+
+struct ggml_tensor * ggml_view_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        size_t                nb1, // row stride in bytes
+        size_t                offset);
+
+struct ggml_tensor * ggml_permute(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   axis0,
+        int                   axis1,
+        int                   axis2,
+        int                   axis3);
+
+// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+struct ggml_tensor * ggml_transpose(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_get_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// set elements above the diagonal to -INF
+// in-place, returns view(a)
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past);
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// rotary position embedding
+// in-place, returns view(a)
+// if mode == 1, skip n_past elements
+// TODO: avoid creating a new tensor every time
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode);
+
+// padding = 1
+// TODO: we don't support extra parameters for now
+//       that's why we are hard-coding the stride, padding, and dilation
+//       not great ..
+struct ggml_tensor * ggml_conv_1d_1s(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_conv_1d_2s(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_flash_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        bool                  masked);
+
+struct ggml_tensor * ggml_flash_ff(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b0,
+        struct ggml_tensor  * b1,
+        struct ggml_tensor  * c0,
+        struct ggml_tensor  * c1);
+
+//
+// automatic differentiation
+//
+
+void ggml_set_param(
+        struct ggml_context * ctx,
+        struct ggml_tensor * tensor);
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+
+struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+
+// print info and performance information for the graph
+void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+// dump the graph into a file using the dot format
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+//
+// optimization
+//
+
+// optimization methods
+enum ggml_opt_type {
+    GGML_OPT_ADAM,
+    GGML_OPT_LBFGS,
+};
+
+// linesearch methods
+enum ggml_linesearch {
+    GGML_LINESEARCH_DEFAULT = 1,
+
+    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
+    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
+    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+};
+
+// optimization return values
+enum ggml_opt_result {
+    GGML_OPT_OK = 0,
+    GGML_OPT_DID_NOT_CONVERGE,
+    GGML_OPT_NO_CONTEXT,
+    GGML_OPT_INVALID_WOLFE,
+    GGML_OPT_FAIL,
+
+    GGML_LINESEARCH_FAIL = -128,
+    GGML_LINESEARCH_MINIMUM_STEP,
+    GGML_LINESEARCH_MAXIMUM_STEP,
+    GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+    GGML_LINESEARCH_INVALID_PARAMETERS,
+};
+
+// optimization parameters
+//
+//   see ggml.c (ggml_opt_default_params) for default values
+//
+struct ggml_opt_params {
+    enum ggml_opt_type type;
+
+    int n_threads;
+
+    // delta-based convergence test
+    //
+    //   if past == 0 - disabled
+    //   if past > 0:
+    //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+    //
+    int past;
+    float delta;
+
+    // maximum number of iterations without improvement
+    //
+    //   if 0 - disabled
+    //   if > 0:
+    //     assume convergence if no cost improvement in this number of iterations
+    //
+    int max_no_improvement;
+
+    bool print_forward_graph;
+    bool print_backward_graph;
+
+    // ADAM parameters
+    struct {
+        int n_iter;
+
+        float alpha; // learning rate
+        float beta1;
+        float beta2;
+        float eps;   // epsilon for numerical stability
+        float eps_f; // epsilon for convergence test
+        float eps_g; // epsilon for convergence test
+    } adam;
+
+    // LBFGS parameters
+    struct {
+        int m; // number of corrections to approximate the inv. Hessian
+        int n_iter;
+        int max_linesearch;
+
+        float eps;      // convergence tolerance
+        float ftol;     // line search tolerance
+        float wolfe;
+        float min_step;
+        float max_step;
+
+        enum ggml_linesearch linesearch;
+    } lbfgs;
+};
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+// optimize the function defined by the tensor f
+enum ggml_opt_result ggml_opt(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f);
+
+//
+// system info
+//
+
+int ggml_cpu_has_avx(void);
+int ggml_cpu_has_avx2(void);
+int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_fma(void);
+int ggml_cpu_has_neon(void);
+int ggml_cpu_has_arm_fma(void);
+int ggml_cpu_has_f16c(void);
+int ggml_cpu_has_fp16_va(void);
+int ggml_cpu_has_wasm_simd(void);
+int ggml_cpu_has_blas(void);
+int ggml_cpu_has_sse3(void);
+int ggml_cpu_has_vsx(void);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/main.cpp b/main.cpp
new file mode 100644 (file)
index 0000000..fb9eb17
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,750 @@
+#include "ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+// default hparams (LLaMA 7B)
+struct llama_hparams {
+    int32_t n_vocab = 32000;
+    int32_t n_ctx   = 512;   // this is provided as user input?
+    int32_t n_embd  = 4096;
+    int32_t n_mult  = 256;
+    int32_t n_head  = 32;
+    int32_t n_layer = 32;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+struct llama_layer {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wq;
+    struct ggml_tensor * wk;
+    struct ggml_tensor * wv;
+    struct ggml_tensor * wo;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+struct llama_model {
+    llama_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * output;
+
+    std::vector<llama_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    int n_ff = 0;
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        hparams.n_ctx = n_ctx;
+
+        n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: n_rot   = %d\n", __func__, hparams.n_rot);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+        printf("%s: n_ff    = %d\n", __func__, n_ff);
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = model.hparams.n_vocab;
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+
+            //if (i < 30000) {
+            //    printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+            //}
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = GGML_TYPE_COUNT;
+    switch (model.hparams.f16) {
+        case 0: wtype = GGML_TYPE_F32;  break;
+        case 1: wtype = GGML_TYPE_F16;  break;
+        case 2: wtype = GGML_TYPE_Q4_0; break;
+        case 3: wtype = GGML_TYPE_Q4_1; break;
+        default:
+                {
+                    fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
+                            __func__, fname.c_str(), model.hparams.f16);
+                    return false;
+                }
+    }
+
+    const ggml_type wtype2 = GGML_TYPE_F32;
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // tok_embeddings
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // output
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
+
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (5 + 10*n_layer)*256; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size,
+            .mem_buffer = NULL,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        model.norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.output = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
+
+        model.tensors["norm.weight"]   = model.norm;
+        model.tensors["output.weight"] = model.output;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
+            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
+            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
+            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
+
+            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+            layer.w2 = ggml_new_tensor_2d(ctx, wtype,   n_ff, n_embd);
+            layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+
+            // map by name
+            model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
+            model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
+            model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
+            model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
+
+            model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        int n_tensors = 0;
+        size_t total_size = 0;
+
+        printf("%s: ", __func__);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            if (0) {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            size_t bpe = 0;
+
+            switch (ftype) {
+                case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                default:
+                        {
+                            fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                            return false;
+                        }
+            };
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+            total_size += ggml_nbytes(tensor);
+            if (++n_tensors % 8 == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+
+        printf(" done\n");
+
+        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+// The GPT-J model requires about 16MB of memory per input token.
+//
+bool llama_eval(
+        const llama_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+    const int n_rot   = hparams.n_rot;
+
+    const int d_key = n_embd/n_head;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = { .n_threads = n_threads };
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
+                        cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                            n_past, n_rot, 0),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            n_past, n_rot, 1),
+                        0, 2, 1, 3);
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            struct ggml_tensor * V_trans =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        1, 2, 0, 3);
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].wo,
+                    cur);
+        }
+
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF);
+
+                // cur = ffn_norm*cur
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
+                        cur);
+            }
+
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model.layers[il].w3,
+                    cur);
+
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w1,
+                    cur);
+
+            // SILU activation
+            cur = ggml_silu(ctx0, cur);
+
+            cur = ggml_mul(ctx0, cur, tmp);
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w2,
+                    cur);
+        }
+
+        cur  = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, inpL);
+
+        // inpL = norm*inpL
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.norm, inpL),
+                    inpL);
+    }
+
+    // lm_head
+    {
+        inpL = ggml_mul_mat(ctx0, model.output, inpL);
+    }
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    llama_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!llama_model_load(params.model, model, vocab, 512)) {  // TODO: set context from user input ??
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("\n");
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    for (int i = 0; i < (int) embd_inp.size(); i++) {
+        printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+    }
+    printf("\n");
+    printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
+    printf("\n\n");
+
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 2) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/quantize.cpp b/quantize.cpp
new file mode 100644 (file)
index 0000000..0bc62db
--- /dev/null
@@ -0,0 +1,330 @@
+#include "ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// TODO: move somewhere else
+#define QK 32
+
+// default hparams (LLaMA76B)
+struct llama_hparams {
+    int32_t n_vocab = 32000;
+    int32_t n_ctx   = 512;   // this is provided as user input?
+    int32_t n_embd  = 4096;
+    int32_t n_mult  = 256;
+    int32_t n_head  = 32;
+    int32_t n_layer = 32;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+
+// quantize a model
+bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    llama_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fout.write((char *) &itype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = hparams.n_vocab;
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to be quantized
+            const std::vector<std::string> k_names = {
+                ".*weight",
+            };
+
+            bool quantize = false;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = true;
+                    break;
+                }
+            }
+
+            // quantize only 2D tensors
+            quantize &= (n_dims == 2);
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        fprintf(stderr, "  type = 2 - q4_0\n");
+        fprintf(stderr, "  type = 3 - q4_1\n");
+        return 1;
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const int itype = atoi(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!llama_model_quantize(fname_inp, fname_out, itype)) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}
diff --git a/utils.cpp b/utils.cpp
new file mode 100644 (file)
index 0000000..70a2ac2
--- /dev/null
+++ b/utils.cpp
@@ -0,0 +1,478 @@
+#include "utils.h"
+
+#include <fstream>
+#include <regex>
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-s" || arg == "--seed") {
+            params.seed = std::stoi(argv[++i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-p" || arg == "--prompt") {
+            params.prompt = argv[++i];
+        } else if (arg == "-n" || arg == "--n_predict") {
+            params.n_predict = std::stoi(argv[++i]);
+        } else if (arg == "--top_k") {
+            params.top_k = std::stoi(argv[++i]);
+        } else if (arg == "--top_p") {
+            params.top_p = std::stof(argv[++i]);
+        } else if (arg == "--temp") {
+            params.temp = std::stof(argv[++i]);
+        } else if (arg == "-b" || arg == "--batch_size") {
+            params.n_batch = std::stoi(argv[++i]);
+        } else if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            gpt_print_usage(argc, argv, params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            gpt_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
+    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
+    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "\n");
+}
+
+std::string gpt_random_prompt(std::mt19937 & rng) {
+    const int r = rng() % 10;
+    switch (r) {
+        case 0: return "So";
+        case 1: return "Once upon a time";
+        case 2: return "When";
+        case 3: return "The";
+        case 4: return "After";
+        case 5: return "If";
+        case 6: return "import";
+        case 7: return "He";
+        case 8: return "She";
+        case 9: return "They";
+        default: return "To";
+    }
+
+    return "The";
+}
+
+void replace(std::string & str, const std::string & needle, const std::string & replacement) {
+    size_t pos = 0;
+    while ((pos = str.find(needle, pos)) != std::string::npos) {
+        str.replace(pos, needle.length(), replacement);
+        pos += replacement.length();
+    }
+}
+
+std::map<std::string, int32_t> json_parse(const std::string & fname) {
+    std::map<std::string, int32_t> result;
+
+    // read file into string
+    std::string json;
+    {
+        std::ifstream ifs(fname);
+        if (!ifs) {
+            fprintf(stderr, "Failed to open %s\n", fname.c_str());
+            exit(1);
+        }
+
+        json = std::string((std::istreambuf_iterator<char>(ifs)),
+                (std::istreambuf_iterator<char>()));
+    }
+
+    if (json[0] != '{') {
+        return result;
+    }
+
+    // parse json
+    {
+        bool has_key  = false;
+        bool in_token = false;
+
+        std::string str_key = "";
+        std::string str_val = "";
+
+        int n = json.size();
+        for (int i = 1; i < n; ++i) {
+            if (!in_token) {
+                if (json[i] == ' ') continue;
+                if (json[i] == '"') {
+                    in_token = true;
+                    continue;
+                }
+            } else {
+                if (json[i] == '\\' && i+1 < n) {
+                    if (has_key == false) {
+                        str_key += json[i];
+                    } else {
+                        str_val += json[i];
+                    }
+                    ++i;
+                } else if (json[i] == '"') {
+                    if (has_key == false) {
+                        has_key = true;
+                        ++i;
+                        while (json[i] == ' ') ++i;
+                        ++i; // :
+                        while (json[i] == ' ') ++i;
+                        if (json[i] != '\"') {
+                            while (json[i] != ',' && json[i] != '}') {
+                                str_val += json[i++];
+                            }
+                            has_key = false;
+                        } else {
+                            in_token = true;
+                            continue;
+                        }
+                    } else {
+                        has_key = false;
+                    }
+
+                    ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
+                    ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
+                    ::replace(str_key, "\\\"",    "\""); // \\\"   -> "
+
+                    try {
+                        result[str_key] = std::stoi(str_val);
+                    } catch (...) {
+                        //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
+
+                    }
+                    str_key = "";
+                    str_val = "";
+                    in_token = false;
+                    continue;
+                }
+                if (has_key == false) {
+                    str_key += json[i];
+                } else {
+                    str_val += json[i];
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<gpt_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.size() == 0) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
+    auto res = gpt_tokenize(vocab, text);
+
+    if (bos) {
+        res.insert(res.begin(), 1); // TODO: replace with vocab.bos
+    }
+
+    //std::vector<gpt_vocab::id> res;
+
+    //if (bos) {
+    //    res.push_back(1); // TODO: replace with vocab.bos
+    //}
+
+    // find the longest token that matches the text
+    //int pos = 0;
+    //while (true) {
+    //    int l = 0;
+    //    int t = 0;
+    //    for (const auto & kv : vocab.id_to_token) {
+    //        if (kv.second.size() < l) continue;
+    //        if (kv.second.size() > text.size() - pos) continue;
+    //        if (text.substr(pos, kv.second.size()) == kv.second) {
+    //            l = kv.second.size();
+    //            t = kv.first;
+    //        }
+    //    }
+
+    //    if (l == 0 && t != 13) {
+    //        break;
+    //    }
+
+    //    res.push_back(t);
+    //    pos += l;
+    //}
+
+    return res;
+}
+
+bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
+    printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
+
+    vocab.token_to_id = ::json_parse(fname);
+
+    for (const auto & kv : vocab.token_to_id) {
+        vocab.id_to_token[kv.second] = kv.first;
+    }
+
+    printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
+
+    // print the vocabulary
+    //for (auto kv : vocab.token_to_id) {
+    //    printf("'%s' -> %d\n", kv.first.data(), kv.second);
+    //}
+
+    return true;
+}
+
+gpt_vocab::id gpt_sample_top_k_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        int    top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng) {
+    int n_logits = vocab.id_to_token.size();
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        }
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int) probs.size(); i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //}
+    //exit(0);
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
+
+size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
+    const int nb = k / qk;
+    const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2);
+
+    assert(k % qk == 0);
+
+    uint8_t pp[qk/2];
+
+    char * pdst = (char *) dst;
+
+    for (int j = 0; j < n; j += k) {
+        float   * pd = (float *)   (pdst + (j/k)*row_size);
+        uint8_t * pb = (uint8_t *) (pd + nb);
+
+        for (int i = 0; i < nb; i++) {
+            float amax = 0.0f; // absolute max
+
+            {
+                for (int l = 0; l < qk; l++) {
+                    const float v = src[j + i*qk + l];
+                    amax = std::max(amax, fabsf(v));
+                }
+
+                const float d = amax / ((1 << 3) - 1);
+                const float id = d ? 1.0f/d : 0.0f;
+
+                pd[i] = d;
+
+                for (int l = 0; l < qk; l += 2) {
+                    const float v0 = (src[j + i*qk + l + 0])*id;
+                    const float v1 = (src[j + i*qk + l + 1])*id;
+
+                    const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
+                    const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+
+                    assert(vi0 >= 0 && vi0 < 16);
+                    assert(vi1 >= 0 && vi1 < 16);
+
+                    hist[vi0]++;
+                    hist[vi1]++;
+
+                    pp[l/2] = vi0 | (vi1 << 4);
+                }
+
+                memcpy(pb + i*qk/2, pp, sizeof(pp));
+            }
+        }
+    }
+
+    return (n/k)*row_size;
+}
+
+size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
+    const int nb = k / qk;
+    const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2);
+
+    assert(k % qk == 0);
+
+    uint8_t pp[qk/2];
+
+    char * pdst = (char *) dst;
+
+    for (int j = 0; j < n; j += k) {
+        float   * pm = (float *)   (pdst + (j/k)*row_size);
+        float   * pd = (float *)   (pm + nb);
+        uint8_t * pb = (uint8_t *) (pd + nb);
+
+        //printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
+
+        for (int i = 0; i < nb; i++) {
+            float min = std::numeric_limits<float>::max();
+            float max = std::numeric_limits<float>::min();
+
+            {
+                for (int l = 0; l < qk; l++) {
+                    const float v = src[j + i*qk + l];
+                    if (v < min) min = v;
+                    if (v > max) max = v;
+                }
+
+                const float d = (max - min) / ((1 << 4) - 1);
+                const float id = d ? 1.0f/d : 0.0f;
+
+                pm[i] = min;
+                pd[i] = d;
+
+                for (int l = 0; l < qk; l += 2) {
+                    const float v0 = (src[j + i*qk + l + 0] - min)*id;
+                    const float v1 = (src[j + i*qk + l + 1] - min)*id;
+
+                    const uint8_t vi0 = round(v0);
+                    const uint8_t vi1 = round(v1);
+
+                    assert(vi0 >= 0 && vi0 < 16);
+                    assert(vi1 >= 0 && vi1 < 16);
+
+                    hist[vi0]++;
+                    hist[vi1]++;
+
+                    pp[l/2] = vi0 | (vi1 << 4);
+                }
+
+                memcpy(pb + i*qk/2, pp, sizeof(pp));
+            }
+        }
+    }
+
+    return (n/k)*row_size;
+}
diff --git a/utils.h b/utils.h
new file mode 100644 (file)
index 0000000..d291964
--- /dev/null
+++ b/utils.h
@@ -0,0 +1,94 @@
+// Various helper functions and utilities
+
+#pragma once
+
+#include <string>
+#include <map>
+#include <vector>
+#include <random>
+#include <thread>
+
+//
+// CLI argument parsing
+//
+
+struct gpt_params {
+    int32_t seed      = -1; // RNG seed
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_predict = 200; // new tokens to predict
+
+    // sampling parameters
+    int32_t top_k = 100;
+    float   top_p = 0.95f;
+    float   temp  = 0.8f;
+
+    int32_t n_batch = 8; // batch size for prompt processing
+
+    std::string model = "models/lamma-7B/ggml-model.bin"; // model path
+    std::string prompt;
+};
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+
+std::string gpt_random_prompt(std::mt19937 & rng);
+
+//
+// Vocab utils
+//
+
+struct gpt_vocab {
+    using id    = int32_t;
+    using token = std::string;
+
+    std::map<token, id> token_to_id;
+    std::map<id, token> id_to_token;
+};
+
+void replace(std::string & str, const std::string & needle, const std::string & replacement);
+
+// poor-man's JSON parsing
+std::map<std::string, int32_t> json_parse(const std::string & fname);
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
+
+// TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
+// ref: https://github.com/google/sentencepiece
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos);
+
+// load the tokens from encoder.json
+bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
+
+// sample next token given probabilities for each embedding
+//
+//   - consider only the top K tokens
+//   - from them, consider only the top tokens with cumulative probability > P
+//
+// TODO: not sure if this implementation is correct
+// TODO: temperature is not implemented
+//
+gpt_vocab::id gpt_sample_top_k_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        int    top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng);
+
+//
+// Quantization
+//
+
+size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist);
+size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist);