]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Initial release
authorGeorgi Gerganov <redacted>
Sun, 18 Sep 2022 17:11:11 +0000 (20:11 +0300)
committerGeorgi Gerganov <redacted>
Sun, 18 Sep 2022 17:11:11 +0000 (20:11 +0300)
34 files changed:
.gitignore [new file with mode: 0644]
CMakeLists.txt [new file with mode: 0644]
LICENSE [new file with mode: 0644]
README.md [new file with mode: 0644]
cmake/BuildTypes.cmake [new file with mode: 0644]
cmake/GitVars.cmake [new file with mode: 0644]
examples/CMakeLists.txt [new file with mode: 0644]
examples/gpt-2/CMakeLists.txt [new file with mode: 0644]
examples/gpt-2/README.md [new file with mode: 0644]
examples/gpt-2/convert-ckpt-to-ggml.py [new file with mode: 0644]
examples/gpt-2/download-ggml-model.sh [new file with mode: 0755]
examples/gpt-2/download-model.sh [new file with mode: 0755]
examples/gpt-2/main.cpp [new file with mode: 0644]
examples/gpt-j/CMakeLists.txt [new file with mode: 0644]
examples/gpt-j/README.md [new file with mode: 0644]
examples/gpt-j/convert-h5-to-ggml.py [new file with mode: 0644]
examples/gpt-j/download-ggml-model.sh [new file with mode: 0755]
examples/gpt-j/download-model.sh [new file with mode: 0755]
examples/gpt-j/main.cpp [new file with mode: 0644]
examples/utils.cpp [new file with mode: 0644]
examples/utils.h [new file with mode: 0644]
include/ggml/ggml.h [new file with mode: 0644]
src/CMakeLists.txt [new file with mode: 0644]
src/ggml.c [new file with mode: 0644]
tests/CMakeLists.txt [new file with mode: 0644]
tests/test-grad0.c [new file with mode: 0644]
tests/test-mul-mat0.c [new file with mode: 0644]
tests/test-vec0.c [new file with mode: 0644]
tests/test-vec1.c [new file with mode: 0644]
tests/test-vec2.c [new file with mode: 0644]
tests/test0.c [new file with mode: 0644]
tests/test1.c [new file with mode: 0644]
tests/test2.c [new file with mode: 0644]
tests/test3.c [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..58cfed0
--- /dev/null
@@ -0,0 +1,10 @@
+build/
+build-debug/
+build-*/
+
+compile_commands.json
+
+.exrc
+.cache
+
+src/arm_neon.h
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644 (file)
index 0000000..378dd76
--- /dev/null
@@ -0,0 +1,71 @@
+cmake_minimum_required (VERSION 3.0)
+project(ggml VERSION 0.1.0)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
+
+if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+    set(GGML_STANDALONE ON)
+    include(cmake/GitVars.cmake)
+    include(cmake/BuildTypes.cmake)
+else()
+    set(GGML_STANDALONE OFF)
+endif()
+
+# options
+
+option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings" ON)
+option(GGML_ALL_WARNINGS_3RD_PARTY  "ggml: enable all compiler warnings in 3rd party libs" OFF)
+
+option(GGML_SANITIZE_THREAD         "ggml: enable thread sanitizer" OFF)
+option(GGML_SANITIZE_ADDRESS        "ggml: enable address sanitizer" OFF)
+option(GGML_SANITIZE_UNDEFINED      "ggml: enable undefined sanitizer" OFF)
+
+option(GGML_BUILD_TESTS             "ggml: build tests"    ${GGML_STANDALONE})
+option(GGML_BUILD_EXAMPLES          "ggml: build examples" ${GGML_STANDALONE})
+
+# sanitizers
+
+if (GGML_SANITIZE_THREAD)
+    set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS} -fsanitize=thread")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
+endif()
+
+if (GGML_SANITIZE_ADDRESS)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}     -fsanitize=address -fno-omit-frame-pointer")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
+endif()
+
+if (GGML_SANITIZE_UNDEFINED)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}     -fsanitize=undefined")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined")
+endif()
+
+#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
+#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
+
+# dependencies
+
+set(CMAKE_C_STANDARD   11)
+set(CMAKE_CXX_STANDARD 11)
+
+find_package(Threads REQUIRED)
+
+# main
+
+if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
+    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
+    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "RelWithDebInfo")
+endif ()
+
+add_subdirectory(src)
+
+if (GGML_BUILD_TESTS)
+    enable_testing()
+    add_subdirectory(tests)
+endif ()
+
+if (GGML_BUILD_EXAMPLES)
+    add_subdirectory(examples)
+endif ()
diff --git a/LICENSE b/LICENSE
new file mode 100644 (file)
index 0000000..fb7ff0c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Georgi Gerganov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644 (file)
index 0000000..550425b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,51 @@
+# ggml
+
+Tensor library in C for machine learning
+
+## Features
+
+- Automatic differentiation (WIP)
+- 16-bit float support
+- ADAM and L-BFGS optimizers
+- Optimized for Arm64 architectures (i.e. MacBook M1) via NEON intrinsics
+- On x86 architectures utilzes AVX intrinsics
+- No third-party dependencies
+- Zero memory allocations during runtime
+
+## Local GPT inference
+
+Using ggml you can run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference locally on your computer without any additional software or hardware. You don't even need to install python or any other third-party library.
+
+The example programs are implemented in C++. They run entirely on the CPU.
+
+Here is how to use them:
+
+```bash
+# Build ggml + examples
+git clone https://github.com/ggerganov/ggml
+cd ggml
+mkdir build && cd build
+cmake ..
+make -j4 gpt-2 gpt-j
+
+# Run the GPT-2 small 117M model
+../examples/gpt-2/download-ggml-model.sh 117M
+./bin/gpt-2 -m models/gpt-2-117M/ggml-model.bin -p "This is an example"
+
+# Run the GPT-J 6B model (requires 12GB disk space and 16GB CPU RAM)
+../examples/gpt-j/download-ggml-model.sh 6B
+./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
+```
+
+This is the inference speed for the different models on my MacBook M1 Pro:
+
+| Model | Size  | Time / Token |
+| ---   | ---   | ---    |
+| GPT-2 |  117M |   5 ms |
+| GPT-2 |  345M |  12 ms |
+| GPT-2 |  774M |  23 ms |
+| GPT-2 | 1558M |  42 ms |
+| ---   | ---   | ---    |
+| GPT-J |    6B | 125 ms |
+
+For more information, checkout the corresponding programs in the [examples](examples) folder.
diff --git a/cmake/BuildTypes.cmake b/cmake/BuildTypes.cmake
new file mode 100644 (file)
index 0000000..a9c7b6c
--- /dev/null
@@ -0,0 +1,54 @@
+# Add new build types
+
+# ReleaseGG - Release with enabled asserts
+
+SET(CMAKE_CXX_FLAGS_RELEASEGG
+    "-O3"
+    CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts."
+    FORCE )
+SET(CMAKE_C_FLAGS_RELEASEGG
+    "-O3"
+    CACHE STRING "Flags used by the compiler during release builds with enabled asserts."
+    FORCE )
+SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG
+    ""
+    CACHE STRING "Flags used for linking binaries during release builds with enabled asserts."
+    FORCE )
+SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG
+    ""
+    CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts."
+    FORCE )
+MARK_AS_ADVANCED(
+    CMAKE_CXX_FLAGS_RELEASEGG
+    CMAKE_C_FLAGS_RELEASEGG
+    CMAKE_EXE_LINKER_FLAGS_RELEASEGG
+    CMAKE_SHARED_LINKER_FLAGS_RELEASEGG )
+
+# RelWithDebInfoGG - RelWithDebInfo with enabled asserts
+
+SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
+    "-O2 -g"
+    CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts."
+    FORCE )
+SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG
+    "-O2 -g"
+    CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts."
+    FORCE )
+SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
+    ""
+    CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts."
+    FORCE )
+SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG
+    ""
+    CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts."
+    FORCE )
+MARK_AS_ADVANCED(
+    CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
+    CMAKE_C_FLAGS_RELWITHDEBINFOGG
+    CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
+    CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG )
+
+if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
+    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG")
+endif()
diff --git a/cmake/GitVars.cmake b/cmake/GitVars.cmake
new file mode 100644 (file)
index 0000000..1a4c24e
--- /dev/null
@@ -0,0 +1,22 @@
+find_package(Git)
+
+# the commit's SHA1
+execute_process(COMMAND
+    "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
+    WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
+    OUTPUT_VARIABLE GIT_SHA1
+    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+# the date of the commit
+execute_process(COMMAND
+    "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
+    WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
+    OUTPUT_VARIABLE GIT_DATE
+    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+# the subject of the commit
+execute_process(COMMAND
+    "${GIT_EXECUTABLE}" log -1 --format=%s
+    WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
+    OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
+    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cdbcfad
--- /dev/null
@@ -0,0 +1,5 @@
+add_library(ggml_utils STATIC utils.cpp)
+target_include_directories(ggml_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
+
+add_subdirectory(gpt-2)
+add_subdirectory(gpt-j)
diff --git a/examples/gpt-2/CMakeLists.txt b/examples/gpt-2/CMakeLists.txt
new file mode 100644 (file)
index 0000000..9960cfe
--- /dev/null
@@ -0,0 +1,6 @@
+#
+# gpt-2
+
+set(TEST_TARGET gpt-2)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
diff --git a/examples/gpt-2/README.md b/examples/gpt-2/README.md
new file mode 100644 (file)
index 0000000..3543bb2
--- /dev/null
@@ -0,0 +1,126 @@
+# gpt-2
+
+This is a C++ example running GPT-2 inference using the [ggml](https://github.com/ggerganov/ggml) library.
+The enitre code of the example is in [main.cpp](main.cpp).
+
+The program runs on the CPU - no video card is required.
+
+The example supports the following models:
+
+| Model | Description  | Disk Size |
+| ---   | ---          | ---       |
+| 117M  | Small model  | 240 MB    |
+| 345M  | Medium model | 680 MB    |
+| 774M  | Large model  | 1.5 GB    |
+| 1558M | XL model     | 3.0 GB    |
+
+Sample performance on MacBook M1 Pro:
+
+| Model | Size  | Time / Token |
+| ---   | ---   | ---    |
+| GPT-2 |  117M |   5 ms |
+| GPT-2 |  345M |  12 ms |
+| GPT-2 |  774M |  23 ms |
+| GPT-2 | 1558M |  42 ms |
+
+Sample output:
+
+```
+$ ./bin/gpt-2 -h
+usage: ./bin/gpt-2 [options]
+
+options:
+  -h, --help            show this help message and exit
+  -s SEED, --seed SEED  RNG seed (default: -1)
+  -t N, --threads N     number of threads to use during computation (default: 8)
+  -p PROMPT, --prompt PROMPT
+                        prompt to start generation with (default: random)
+  -n N, --n_predict N   number of tokens to predict (default: 200)
+  --top_k N             top-k sampling (default: 40)
+  --top_p N             top-p sampling (default: 0.9)
+  --temp N              temperature (default: 1.0)
+  -b N, --batch_size N  batch size for prompt processing (default: 8)
+  -m FNAME, --model FNAME
+                        model path (default: models/gpt-2-117M/ggml-model.bin)
+
+$ ./bin/gpt-2
+gpt2_model_load: loading model from 'models/gpt-2-117M/ggml-model.bin'
+gpt2_model_load: n_vocab = 50257
+gpt2_model_load: n_ctx   = 1024
+gpt2_model_load: n_embd  = 768
+gpt2_model_load: n_head  = 12
+gpt2_model_load: n_layer = 12
+gpt2_model_load: f16     = 1
+gpt2_model_load: ggml ctx size = 311.12 MB
+gpt2_model_load: memory size =    72.00 MB, n_mem = 12288
+gpt2_model_load: model size  =   239.08 MB
+main: number of tokens in prompt = 1
+
+So this is going to be the end of the line for us.
+
+If the Dolphins continue to do their business, it's possible that the team could make a bid to bring in new defensive coordinator Scott Linehan.
+
+Linehan's job is a little daunting, but he's a great coach and an excellent coach. I don't believe we're going to make the playoffs.
+
+We're going to have to work hard to keep our heads down and get ready to go.<|endoftext|>
+
+main: mem per token =  2048612 bytes
+main:     load time =   106.32 ms
+main:   sample time =     7.10 ms
+main:  predict time =   506.40 ms / 5.06 ms per token
+main:    total time =   629.84 ms
+```
+
+## Downloading and converting the original models
+
+You can download the original model files using the [download-model.sh](download-model.sh) Bash script.
+The model is in Tensorflow format, so before using it with ggml, we need to convert it to appropriate format.
+This is done via the [convert-ckpt-to-ggml.py](convert-ckpt-to-ggml.py) python script.
+
+Here is the entire process for the GPT-2 117M model:
+
+```
+cd ggml/build
+../examples/gpt-2/download-model.sh 117M
+
+Downloading model 117M ...
+models/gpt-2-117M/checkpoint                      100%[=============================>]      77  --.-KB/s    in 0s
+models/gpt-2-117M/encoder.json                    100%[=============================>]   1018K  1.20MB/s    in 0.8s
+models/gpt-2-117M/hparams.json                    100%[=============================>]      90  --.-KB/s    in 0s
+models/gpt-2-117M/model.ckpt.data-00000-of-00001  100%[=============================>] 474.70M  1.21MB/s    in 8m 39s
+models/gpt-2-117M/model.ckpt.index                100%[=============================>]   5.09K  --.-KB/s    in 0s
+models/gpt-2-117M/model.ckpt.meta                 100%[=============================>] 460.11K   806KB/s    in 0.6s
+models/gpt-2-117M/vocab.bpe                       100%[=============================>] 445.62K   799KB/s    in 0.6s
+Done! Model '117M' saved in 'models/gpt-2-117M/'
+
+Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
+
+  python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/
+
+```
+
+This conversion requires that you have python and Tensorflow installed on your computer.
+Still, if you want to avoid this, you can download the already converted ggml models as
+described below.
+
+## Downloading the ggml model directly
+
+For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples.
+This way, you can directly download a single binary file and start using it. No python or Tensorflow is required.
+
+Here is how to get the 117M ggml model:
+
+```
+cd ggml/build
+../examples/gpt-2/download-ggml-model.sh 117M
+
+Downloading ggml model 117M ...
+models/gpt-2-117M/ggml-model.bin         100%[===============================>] 239.58M  8.52MB/s    in 28s
+Done! Model '117M' saved in 'models/gpt-2-117M/ggml-model.bin'
+You can now use it like this:
+
+  $ ./bin/gpt-2 -m models/gpt-2-117M/ggml-model.bin -p "This is an example"
+
+```
+
+At some point, I might stop hosting these models. So in that case, simply revert to the manual process above.
diff --git a/examples/gpt-2/convert-ckpt-to-ggml.py b/examples/gpt-2/convert-ckpt-to-ggml.py
new file mode 100644 (file)
index 0000000..09824a8
--- /dev/null
@@ -0,0 +1,127 @@
+# Convert a model checkpoint to a ggml compatible file
+#
+# Load the model using TensorFlow.
+# Iterate over all variables and write them to a binary file.
+#
+# For each variable, write the following:
+#   - Number of dimensions (int)
+#   - Name length (int)
+#   - Dimensions (int[n_dims])
+#   - Name (char[name_length])
+#   - Data (float[n_dims])
+#
+# By default, the bigger matrices are converted to 16-bit floats.
+# This can be disabled by adding the "use-f32" CLI argument.
+#
+# At the start of the ggml file we write the model parameters
+# and vocabulary.
+#
+
+import sys
+import json
+import struct
+import numpy as np
+import tensorflow as tf
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+if len(sys.argv) < 2:
+    print("Usage: convert-ckpt-to-ggml.py dir-model [use-f32]\n")
+    sys.exit(1)
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+with open(dir_model + "/encoder.json", "r") as f:
+    encoder = json.load(f)
+
+with open(dir_model + "/hparams.json", "r") as f:
+    hparams = json.load(f)
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 2:
+    use_f16 = False
+    fname_out = sys.argv[1] + "/ggml-model-f32.bin"
+
+list_vars = tf.train.list_variables(dir_model)
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["n_vocab"]))
+fout.write(struct.pack("i", hparams["n_ctx"]))
+fout.write(struct.pack("i", hparams["n_embd"]))
+fout.write(struct.pack("i", hparams["n_head"]))
+fout.write(struct.pack("i", hparams["n_layer"]))
+fout.write(struct.pack("i", use_f16))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", len(encoder)))
+for key in encoder:
+    text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for name, shape in list_vars:
+    print("Processing variable: " + name + " with shape: ", shape)
+
+    data = tf.train.load_variable(dir_model, name).squeeze()
+    n_dims = len(data.shape);
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype = 0;
+    if use_f16:
+        # match name:
+        #  "model/wte"
+        #  "model/h.*/attn/c_attn/w"
+        #  "model/h.*/attn/c_proj/w"
+        #  "model/h.*/mlp/c_fc/w"
+        #  "model/h.*/mlp/c_proj/w"
+        if name == "model/wte" or name[-2:] == "/w":
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype = 1
+
+    # for efficiency - transpose the projection matrices
+    if name[-13:] == "/mlp/c_proj/w":
+        print("  Transposing")
+        data = data.transpose()
+
+    # header
+    str = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str);
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/gpt-2/download-ggml-model.sh b/examples/gpt-2/download-ggml-model.sh
new file mode 100755 (executable)
index 0000000..9708618
--- /dev/null
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# This script downloads GPT-2 model files that have already been converted to ggml format.
+# This way you don't have to convert them yourself.
+#
+# If you want to download the original GPT-2 model files, use the "download-model.sh" script instead.
+
+ggml_path=$(dirname $(realpath $0))
+
+# GPT-2 models
+models=( "117M" "345M" "774M" "1558M" )
+
+# list available models
+function list_models {
+    printf "\n"
+    printf "  Available models:"
+    for model in "${models[@]}"; do
+        printf " $model"
+    done
+    printf "\n\n"
+}
+
+if [ "$#" -ne 1 ]; then
+    printf "Usage: $0 <model>\n"
+    list_models
+
+    exit 1
+fi
+
+model=$1
+
+if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
+    printf "Invalid model: $model\n"
+    list_models
+
+    exit 1
+fi
+
+# download ggml model
+
+printf "Downloading ggml model $model ...\n"
+
+mkdir -p models/gpt-2-$model
+
+wget --quiet --show-progress -O models/gpt-2-$model/ggml-model.bin https://ggml.ggerganov.com/ggml-model-gpt-2-$model.bin
+
+if [ $? -ne 0 ]; then
+    printf "Failed to download ggml model $model \n"
+    printf "Please try again later or download the original GPT-2 model files and convert them yourself.\n"
+    exit 1
+fi
+
+printf "Done! Model '$model' saved in 'models/gpt-2-$model/ggml-model.bin'\n"
+printf "You can now use it like this:\n\n"
+printf "  $ ./bin/gpt-2 -m models/gpt-2-$model/ggml-model.bin -p \"This is an example\"\n"
+printf "\n"
diff --git a/examples/gpt-2/download-model.sh b/examples/gpt-2/download-model.sh
new file mode 100755 (executable)
index 0000000..f0c62f4
--- /dev/null
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+ggml_path=$(dirname $(realpath $0))
+
+# GPT-2 models
+models=( "117M" "345M" "774M" "1558M" )
+
+# list available models
+function list_models {
+    printf "\n"
+    printf "  Available models:"
+    for model in "${models[@]}"; do
+        printf " $model"
+    done
+    printf "\n\n"
+}
+
+if [ "$#" -ne 1 ]; then
+    printf "Usage: $0 <model>\n"
+    list_models
+
+    exit 1
+fi
+
+model=$1
+
+if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
+    printf "Invalid model: $model\n"
+    list_models
+
+    exit 1
+fi
+
+# download model
+
+printf "Downloading model $model ...\n"
+
+mkdir -p models/gpt-2-$model
+
+for file in checkpoint encoder.json hparams.json model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta vocab.bpe; do
+    wget --quiet --show-progress -O models/gpt-2-$model/$file https://openaipublic.blob.core.windows.net/gpt-2/models/$model/$file
+done
+
+printf "Done! Model '$model' saved in 'models/gpt-2-$model/'\n\n"
+printf "Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.\n"
+printf "\n"
+printf "  python $ggml_path/convert-ckpt-to-ggml.py models/gpt-2-$model/\n"
+printf "\n"
diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
new file mode 100644 (file)
index 0000000..b515685
--- /dev/null
@@ -0,0 +1,783 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 50257;
+    int32_t n_ctx   = 1024;
+    int32_t n_embd  = 768;
+    int32_t n_head  = 12;
+    int32_t n_layer = 12;
+    int32_t f16     = 1;
+};
+
+struct gpt2_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    struct ggml_tensor * ln_2_g;
+    struct ggml_tensor * ln_2_b;
+
+    // attention
+    struct ggml_tensor * c_attn_attn_w;
+    struct ggml_tensor * c_attn_attn_b;
+
+    struct ggml_tensor * c_attn_proj_w;
+    struct ggml_tensor * c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt2_model {
+    gpt2_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte; // position embedding
+    struct ggml_tensor * wpe; //    token embedding
+
+    std::vector<gpt2_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats
+    // in order to save memory and also to speed up the computation
+    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+
+        ctx_size += n_vocab*n_embd*ggml_type_size(wtype);         // wte
+        ctx_size +=   n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
+
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
+
+        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype));         // c_attn_attn_w
+        ctx_size += n_layer*(       3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype));           // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_size(GGML_TYPE_F32));   // c_attn_proj_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (6 + 12*n_layer)*256; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size,
+            .mem_buffer = NULL,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.wte = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+
+        // map by name
+        model.tensors["model/ln_f/g"] = model.ln_f_g;
+        model.tensors["model/ln_f/b"] = model.ln_f_b;
+
+        model.tensors["model/wte"] = model.wte;
+        model.tensors["model/wpe"] = model.wpe;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.ln_2_g             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b             = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_attn_w      = ggml_new_tensor_2d(ctx, wtype,         3*n_embd, n_embd);
+            layer.c_attn_attn_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+
+            layer.c_attn_proj_w      = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_mlp_fc_w         = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_fc_b         = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"]        = layer.ln_1_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"]        = layer.ln_2_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"]        = layer.ln_2_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w_trans;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        size_t total_size = 0;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+
+            if (nelements*bpe != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+            total_size += ggml_nbytes(tensor);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted probabilities of the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+
+    const int d_key = n_embd/n_head;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = { .n_threads = n_threads };
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    for (int i = 0; i < N; ++i) {
+        ((int32_t *) position->data)[i] = n_past + i;
+    }
+
+    // wte + wpe
+    struct ggml_tensor * inpL =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.wte, embd),
+                ggml_get_rows(ctx0, model.wpe, position));
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // [ 768, N]
+            cur = ggml_norm(ctx0, inpL);
+
+            // cur = ln_1_g*cur + ln_1_b
+            // [ 768, N]
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+                        cur),
+                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+        }
+
+        // attn
+        // [2304, 768] - model.layers[il].c_attn_attn_w
+        // [2304,   1] - model.layers[il].c_attn_attn_b
+        // [ 768,   N] - cur (in)
+        // [2304,   N] - cur (out)
+        //
+        // cur = attn_w*cur + attn_b
+        // [2304, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
+                    cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            // [64, N, 12]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            // [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // K * Q
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            // [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        1, 2, 0, 3);
+
+            // KQV = transpose(V) * KQ_soft_max
+            // [64, N, 12]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // [64, 12, N]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // [768, N]
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        }
+
+        // projection
+        // [ 768, 768] - model.layers[il].c_attn_proj_w
+        // [ 768,   1] - model.layers[il].c_attn_proj_b
+        // [ 768,   N] - cur (in)
+        // [ 768,   N] - cur (out)
+        //
+        // cur = proj_w*cur + proj_b
+        // [768, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
+                    cur);
+        }
+
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF);
+
+                // cur = ln_2_g*cur + ln_2_b
+                // [ 768, N]
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
+                            cur),
+                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
+            }
+
+            // fully connected
+            // [3072, 768] - model.layers[il].c_mlp_fc_w
+            // [3072,   1] - model.layers[il].c_mlp_fc_b
+            // [ 768,   N] - cur (in)
+            // [3072,   N] - cur (out)
+            //
+            // cur = fc_w*cur + fc_b
+            // [3072, N]
+            cur = ggml_mul_mat(ctx0,
+                    ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
+                    cur);
+
+            // GELU activation
+            // [3072, N]
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
+            // [ 768,    1] - model.layers[il].c_mlp_proj_b
+            // [3072,    N] - cur (in)
+            // [ 768,    N] - cur (out)
+            //
+            // cur = proj_w*cur + proj_b
+            // [768, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w_trans,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
+                    cur);
+        }
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // norm
+    {
+        // [ 768, N]
+        inpL = ggml_norm(ctx0, inpL);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        // [ 768, N]
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.ln_f_g, inpL),
+                    inpL),
+                ggml_repeat(ctx0, model.ln_f_b, inpL));
+    }
+
+    // inpL = WTE * inpL
+    // [ 768, 50257] - model.wte
+    // [ 768, N]     - inpL
+    inpL = ggml_mul_mat(ctx0, model.wte, inpL);
+
+    // to logits
+    inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-2-117M/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gpt2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> embd_w;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    printf("\n");
+
+    // submit the input prompt token-by-token
+    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/gpt-j/CMakeLists.txt b/examples/gpt-j/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4199a3f
--- /dev/null
@@ -0,0 +1,6 @@
+#
+# gpt-j
+
+set(TEST_TARGET gpt-j)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
diff --git a/examples/gpt-j/README.md b/examples/gpt-j/README.md
new file mode 100644 (file)
index 0000000..608eac1
--- /dev/null
@@ -0,0 +1,155 @@
+# gpt-j
+
+Local GPT-J inference on your computer using C/C++
+
+No video card required. You just need to have 16 GB of RAM.
+
+For example, you can run this on a 16 GB MacBook M1.
+
+## Motivation
+
+The GPT-J 6B model is the open-source alternative to OpenAI's GPT-3. It's basically a neural network that
+allows you to generate coherent, human-like text given a certain context (prompt).
+
+The GPT-J model is quite big - the compact version of the model uses 16-bit floating point representation
+of the weights and is still 12 GB big. This means that in order to run inference on your computer, you
+would need to have a video card with at least 12 GB of video RAM. Alternatively, you can try to run the
+python implementations on the CPU, but that would probably not be very efficient as they are primarily
+optimized for running on a GPU (or at least this is my guess - I don't have much experience with python).
+
+Looking on the internet, I couldn't find a dedicated CPU implementation that would allow me to run the model
+without a high-end video card. So I decided to write my own inference using a custom build tensor library.
+The tensor library (called [ggml](https://github.com/ggerganov/ggml), written in C) is in early development
+stage, but it already allows me to run the GPT-J model.
+
+On my MacBook M1 Pro, I achieve an inference speed of about `125 ms/token` or about 2-3 words per second.
+
+Here is a sample run with prompt `int main(int argc, char ** argv) {`:
+
+```
+$ time ./bin/gpt-j -p "int main(int argc, char ** argv) {"
+
+gptj_model_load: loading model from 'models/gpt-j-6B/ggml-model.bin' - please wait ...
+gptj_model_load: n_vocab = 50400
+gptj_model_load: n_ctx   = 2048
+gptj_model_load: n_embd  = 4096
+gptj_model_load: n_head  = 16
+gptj_model_load: n_layer = 28
+gptj_model_load: n_rot   = 64
+gptj_model_load: f16     = 1
+gptj_model_load: ggml ctx size = 13334.86 MB
+gptj_model_load: memory_size =  1792.00 MB, n_mem = 57344
+gptj_model_load: ................................... done
+gptj_model_load: model size = 11542.79 MB / num tensors = 285
+main: number of tokens in prompt = 13
+
+int main(int argc, char ** argv) {
+    (void)argc;
+    (void)argv;
+
+    {
+        struct sockaddr_in addr;
+        int addrlen;
+        char * ip = "192.168.1.4";
+        int i;
+
+        if ( (addrlen = sizeof(addr)) == -1 )
+            return -1;
+
+        for (i = 0; i < 10; ++i) {
+            addr.sin_family = AF_INET;
+            addr.sin_addr.s_addr = inet_addr(ip);
+
+main: mem per token = 16430420 bytes
+main:     load time =  6211.48 ms
+main:   sample time =    13.74 ms
+main:  predict time = 26420.34 ms / 124.62 ms per token
+main:    total time = 33035.37 ms
+
+real   0m33.171s
+user   3m32.269s
+sys         0m3.686s
+
+$
+```
+
+It took ~6.2 seconds to load the model to memory. After that, it took ~26.4 seconds to generate 200
+tokens of what looks like to be the beginning of a networking program in C. Pretty cool!
+
+## Implementation details
+
+The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core
+computations are performed by the `ggml` library.
+
+The most performance critical part of the implementation is of course the matrix multiplication routine.
+99% of the time is spent here, so it is important to optimize this as much as possible.
+
+On Arm64, I utilize the 128-bit NEON intrinsics for 16-bit floating point operations:
+
+https://github.com/ggerganov/ggml/blob/1548ac6743c594cc920ccb3503444b0e2bdf4d56/src/ggml.c#L187-L243
+
+These instructions allow each core to operate simultaneously on 64 floating point numbers. I'm no expert
+in SIMD, but after quite some trials this was the most efficient code for dot product that I could come up
+with. Combined with the parallel computation on 8 CPU threads, I think I got close to the maximum performance
+that one could possibly get on the M1 CPU. Still, I'm curious to know if there is a more efficient way to
+implement this.
+
+One interesting property of the GPT-J transformer architecture is that it allows you to perform part
+of the inference in parallel - i.e. the Feed-forward layer can be computed in parallel to the Self-Attention
+layer:
+
+https://github.com/ggerganov/ggml/blob/1548ac6743c594cc920ccb3503444b0e2bdf4d56/examples/gpt-j/main.cpp#L507-L531
+
+So I thought why not bring in the M1 GPU to compute half of the neural network in parallel to the CPU.
+Thanks to the shared memory model, it was relatively easy to offload half of the computation to the GPU
+using [Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders).
+However, to my surprise, I did not get any performance improvement at all. My conclusion was that the
+8-thread NEON CPU computation is basically saturating the memory bandwidth of the M1 and since the CPU
+and the GPU on the MacBook are sharing that bandwidth, it does not help to offload the computation to the
+GPU. Another observation was that the MPS GPU matrix multiplication using 16-bit floats had the same
+performance as the 8-thread NEON CPU implementation. Again, I explain this with a saturated memory channel.
+But of course, I could be totally wrong and somehow my implementation wasn't utilizing the resources 
+correctly.
+
+Another property of my implementation is that it does not perform any memory allocations once the model
+is loaded into memory. All required memory is allocated at the start of the program.
+
+## Usage
+
+If you want to give this a try and you are on Linux or Mac OS, simply follow these instructions:
+
+```bash
+# Clone the ggml library and build the gpt-j example
+git clone https://github.com/ggerganov/ggml
+cd ggml
+mkdir build && cd build
+cmake ..
+make -j4 gpt-j
+
+# Download the ggml-compatible GPT-J 6B model (requires 12GB disk space)
+../examples/gpt-j/download-ggml-model.sh 6B
+
+# Run the inference (requires 16GB of CPU RAM)
+./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
+```
+
+To run the `gpt-j` tool, you need the 12GB `ggml-model.bin` file which contains the GPT-J model in
+[ggml](https://github.com/ggerganov/ggml) format. In the instructions above, I download the binary file
+directly from one of my servers, using the [download-ggml-model.sh](download-ggml-model.sh) script.
+
+---
+
+Alternatively, you can perform the conversion yourself.
+
+First, you need to download the full GPT-J model from here: https://huggingface.co/EleutherAI/gpt-j-6B
+
+Note that the full model is quite big - about 72 GB. After you download it, you need to make the
+conversion using the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script. This will generate the
+`ggml-model.bin` file, which you can then use with the `gpt-j` program.
+
+## GPT-2
+
+I have also implemented a tool for CPU inference using the smaller GPT-2 models. They have worse
+quality compared to GPT-J, but are much faster to execute.
+
+Checkout the GPT-2 example here: [gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
diff --git a/examples/gpt-j/convert-h5-to-ggml.py b/examples/gpt-j/convert-h5-to-ggml.py
new file mode 100644 (file)
index 0000000..a1efecb
--- /dev/null
@@ -0,0 +1,150 @@
+# Convert GPT-J-6B h5 transformer model to ggml format
+#
+# Load the model using GPTJForCausalLM.
+# Iterate over all variables and write them to a binary file.
+#
+# For each variable, write the following:
+#   - Number of dimensions (int)
+#   - Name length (int)
+#   - Dimensions (int[n_dims])
+#   - Name (char[name_length])
+#   - Data (float[n_dims])
+#
+# By default, the bigger matrices are converted to 16-bit floats.
+# This can be disabled by adding the "use-f32" CLI argument.
+#
+# At the start of the ggml file we write the model parameters
+# and vocabulary.
+#
+
+import sys
+import struct
+import json
+import torch
+import numpy as np
+
+from transformers import GPTJForCausalLM
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+if len(sys.argv) < 2:
+    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+    sys.exit(1)
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+with open(dir_model + "/vocab.json", "r") as f:
+    encoder = json.load(f)
+
+with open(dir_model + "/added_tokens.json", "r") as f:
+    encoder_added = json.load(f)
+
+with open(dir_model + "/config.json", "r") as f:
+    hparams = json.load(f)
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 2:
+    use_f16 = False
+    fname_out = sys.argv[1] + "/ggml-model-f32.bin"
+
+model = GPTJForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
+#print (model)
+
+list_vars = model.state_dict()
+#print (list_vars)
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("i", hparams["n_positions"]))
+fout.write(struct.pack("i", hparams["n_embd"]))
+fout.write(struct.pack("i", hparams["n_head"]))
+fout.write(struct.pack("i", hparams["n_layer"]))
+fout.write(struct.pack("i", hparams["rotary_dim"]))
+fout.write(struct.pack("i", use_f16))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", len(encoder) + len(encoder_added)))
+for key in encoder:
+    text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for key in encoder_added:
+    text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+
+    # we don't need these
+    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+        print("  Skipping variable: " + name)
+        continue
+
+    n_dims = len(data.shape);
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype = 0;
+    if use_f16:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype = 1
+
+    # for efficiency - transpose these matrices:
+    #  "transformer.h.*.mlp.fc_in.weight
+    #  "transformer.h.*.attn.out_proj.weight
+    #  "transformer.h.*.attn.q_proj.weight"
+    #  "transformer.h.*.attn.k_proj.weight"
+    #  "transformer.h.*.attn.v_proj.weight"
+    if name.endswith(".mlp.fc_in.weight")     or \
+       name.endswith(".attn.out_proj.weight") or \
+       name.endswith(".attn.q_proj.weight")   or \
+       name.endswith(".attn.k_proj.weight")   or \
+       name.endswith(".attn.v_proj.weight"):
+        print("  Transposing")
+        data = data.transpose()
+
+    # header
+    str = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str);
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/gpt-j/download-ggml-model.sh b/examples/gpt-j/download-ggml-model.sh
new file mode 100755 (executable)
index 0000000..f6f5791
--- /dev/null
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# This script downloads GPT-J model files that have already been converted to ggml format.
+# This way you don't have to convert them yourself.
+#
+# If you want to download the original GPT-J model files, use the "download-model.sh" script instead.
+
+ggml_path=$(dirname $(realpath $0))
+
+# GPT-J models
+models=( "6B" )
+
+# list available models
+function list_models {
+    printf "\n"
+    printf "  Available models:"
+    for model in "${models[@]}"; do
+        printf " $model"
+    done
+    printf "\n\n"
+}
+
+if [ "$#" -ne 1 ]; then
+    printf "Usage: $0 <model>\n"
+    list_models
+
+    exit 1
+fi
+
+model=$1
+
+if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
+    printf "Invalid model: $model\n"
+    list_models
+
+    exit 1
+fi
+
+# download ggml model
+
+printf "Downloading ggml model $model ...\n"
+
+mkdir -p models/gpt-j-$model
+
+wget --quiet --show-progress -O models/gpt-j-$model/ggml-model.bin https://ggml.ggerganov.com/ggml-model-gpt-j-$model.bin
+
+if [ $? -ne 0 ]; then
+    printf "Failed to download ggml model $model \n"
+    printf "Please try again later or download the original GPT-J model files and convert them yourself.\n"
+    exit 1
+fi
+
+printf "Done! Model '$model' saved in 'models/gpt-j-$model/ggml-model.bin'\n"
+printf "You can now use it like this:\n\n"
+printf "  $ ./bin/gpt-j -m models/gpt-j-$model/ggml-model.bin -p \"This is an example\"\n"
+printf "\n"
diff --git a/examples/gpt-j/download-model.sh b/examples/gpt-j/download-model.sh
new file mode 100755 (executable)
index 0000000..c773baf
--- /dev/null
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+printf "To obtain the GPT-J 6B model files, please visit: https://huggingface.co/EleutherAI/gpt-j-6B\n\n"
+
+printf "The model is very big. For example, the reposirory above is 72GB in size.\n"
+printf "If you are sure that you want to clone it, simply run the following command:\n\n"
+
+printf " $ git clone https://huggingface.co/EleutherAI/gpt-j-6B models/gpt-j-6B\n\n"
+
+printf "Alternatively, use the 'download-ggml-model.sh' script to download a 12GB ggml version of the model.\n"
+printf "This version is enough to run inference using the ggml library.\n\n"
diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp
new file mode 100644 (file)
index 0000000..1e724da
--- /dev/null
@@ -0,0 +1,723 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+// default hparams (GPT-J 6B)
+struct gptj_hparams {
+    int32_t n_vocab = 50400;
+    int32_t n_ctx   = 2048;
+    int32_t n_embd  = 4096;
+    int32_t n_head  = 16;
+    int32_t n_layer = 28;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+struct gptj_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    // attention
+    struct ggml_tensor * c_attn_q_proj_w;
+    struct ggml_tensor * c_attn_k_proj_w;
+    struct ggml_tensor * c_attn_v_proj_w;
+
+    struct ggml_tensor * c_attn_proj_w;
+
+    // ff
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w_trans;
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gptj_model {
+    gptj_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte; // position embedding
+
+    struct ggml_tensor * lmh_g; // language model head
+    struct ggml_tensor * lmh_b; // language model bias
+
+    std::vector<gptj_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: n_rot   = %d\n", __func__, hparams.n_rot);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats
+    // in order to save memory and also to speed up the computation
+    const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+
+        ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // wte
+
+        ctx_size += n_embd*n_vocab*ggml_type_size(wtype);         // lmh_g
+        ctx_size +=        n_vocab*ggml_type_size(GGML_TYPE_F32); // lmh_b
+
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_q_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_k_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_v_proj_w
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype));         // c_mlp_proj_w_trans
+        ctx_size += n_layer*(         n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (5 + 10*n_layer)*256; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size,
+            .mem_buffer = NULL,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.wte    = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.lmh_g  = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.lmh_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
+
+        // map by name
+        model.tensors["transformer.wte.weight"] = model.wte;
+
+        model.tensors["transformer.ln_f.weight"] = model.ln_f_g;
+        model.tensors["transformer.ln_f.bias"]   = model.ln_f_b;
+
+        model.tensors["lm_head.weight"] = model.lmh_g;
+        model.tensors["lm_head.bias"]   = model.lmh_b;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g                = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b                = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_q_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_k_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+            layer.c_attn_v_proj_w       = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+
+            layer.c_attn_proj_w         = ggml_new_tensor_2d(ctx, wtype,           n_embd,   n_embd);
+
+            layer.c_mlp_fc_w            = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);
+            layer.c_mlp_fc_b            = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w_trans    = ggml_new_tensor_2d(ctx, wtype,         4*n_embd,   n_embd);
+            layer.c_mlp_proj_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"]          = layer.ln_1_g;
+            model.tensors["transformer.h." + std::to_string(i) + ".ln_1.bias"]            = layer.ln_1_b;
+
+            model.tensors["transformer.h." + std::to_string(i) + ".attn.q_proj.weight"]   = layer.c_attn_q_proj_w;
+            model.tensors["transformer.h." + std::to_string(i) + ".attn.k_proj.weight"]   = layer.c_attn_k_proj_w;
+            model.tensors["transformer.h." + std::to_string(i) + ".attn.v_proj.weight"]   = layer.c_attn_v_proj_w;
+
+            model.tensors["transformer.h." + std::to_string(i) + ".attn.out_proj.weight"] = layer.c_attn_proj_w;
+
+            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"]     = layer.c_mlp_fc_w;
+            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"]       = layer.c_mlp_fc_b;
+
+            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"]    = layer.c_mlp_proj_w_trans;
+            model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"]      = layer.c_mlp_proj_b;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        int n_tensors = 0;
+        size_t total_size = 0;
+
+        printf("%s: ", __func__);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            const size_t bpe = tensor->type == GGML_TYPE_I8 ? 1 : (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+
+            if (nelements*bpe != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+            total_size += ggml_nbytes(tensor);
+            if (++n_tensors % 8 == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+
+        printf(" done\n");
+
+        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted probabilities of the next token
+//
+// The GPT-J model requires about 16MB of memory per input token.
+//
+bool gptj_eval(
+        const gptj_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+    const int n_rot   = hparams.n_rot;
+
+    const int d_key = n_embd/n_head;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = { .n_threads = n_threads };
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    // wte
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            // cur = ln_1_g*cur + ln_1_b
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+                        cur),
+                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+        }
+
+        struct ggml_tensor * inpSA = cur;
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_q_proj_w), cur);
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_k_proj_w), cur);
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_v_proj_w), cur);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                            n_past, n_rot, 0),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            n_past, n_rot, 1),
+                        0, 2, 1, 3);
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            struct ggml_tensor * V_trans =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        1, 2, 0, 3);
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
+                    cur);
+        }
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        // this is independent of the self-attention result, so it could be done in parallel to the self-attention
+        {
+            // note here we pass inpSA instead of cur
+            cur = ggml_mul_mat(ctx0,
+                    ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
+                    inpSA);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
+                    cur);
+
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // cur = proj_w*cur + proj_b
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w_trans,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
+                    cur);
+        }
+
+        // self-attention + FF
+        cur  = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpL);
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, inpL);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.ln_f_g, inpL),
+                    inpL),
+                ggml_repeat(ctx0, model.ln_f_b, inpL));
+    }
+
+    // lm_head
+    {
+        inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
+
+        inpL = ggml_add(ctx0,
+                ggml_repeat(ctx0, model.lmh_b, inpL),
+                inpL);
+    }
+
+    // to logits
+    inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-j-6B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gptj_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gptj_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> embd_w;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    printf("\n");
+
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gptj_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/utils.cpp b/examples/utils.cpp
new file mode 100644 (file)
index 0000000..fbd9ab4
--- /dev/null
@@ -0,0 +1,336 @@
+#include "utils.h"
+
+#include <fstream>
+#include <regex>
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-s" || arg == "--seed") {
+            params.seed = std::stoi(argv[++i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-p" || arg == "--prompt") {
+            params.prompt = argv[++i];
+        } else if (arg == "-n" || arg == "--n_predict") {
+            params.n_predict = std::stoi(argv[++i]);
+        } else if (arg == "--top_k") {
+            params.top_k = std::stoi(argv[++i]);
+        } else if (arg == "--top_p") {
+            params.top_p = std::stof(argv[++i]);
+        } else if (arg == "--temp") {
+            params.temp = std::stof(argv[++i]);
+        } else if (arg == "-b" || arg == "--batch_size") {
+            params.n_batch = std::stoi(argv[++i]);
+        } else if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            gpt_print_usage(argc, argv, params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            gpt_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
+    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
+    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "\n");
+}
+
+void replace(std::string & str, const std::string & needle, const std::string & replacement) {
+    size_t pos = 0;
+    while ((pos = str.find(needle, pos)) != std::string::npos) {
+        str.replace(pos, needle.length(), replacement);
+        pos += replacement.length();
+    }
+}
+
+// poor-man's JSON parsing
+std::map<std::string, int32_t> json_parse(const std::string & fname) {
+    std::map<std::string, int32_t> result;
+
+    // read file into string
+    std::string json;
+    {
+        std::ifstream ifs(fname);
+        if (!ifs) {
+            fprintf(stderr, "Failed to open %s\n", fname.c_str());
+            exit(1);
+        }
+
+        json = std::string((std::istreambuf_iterator<char>(ifs)),
+                (std::istreambuf_iterator<char>()));
+    }
+
+    if (json[0] != '{') {
+        return result;
+    }
+
+    // parse json
+    {
+        bool has_key  = false;
+        bool in_token = false;
+
+        std::string str_key = "";
+        std::string str_val = "";
+
+        int n = json.size();
+        for (int i = 1; i < n; ++i) {
+            if (!in_token) {
+                if (json[i] == ' ') continue;
+                if (json[i] == '"') {
+                    in_token = true;
+                    continue;
+                }
+            } else {
+                if (json[i] == '\\' && i+1 < n) {
+                    if (has_key == false) {
+                        str_key += json[i];
+                    } else {
+                        str_val += json[i];
+                    }
+                    ++i;
+                } else if (json[i] == '"') {
+                    if (has_key == false) {
+                        has_key = true;
+                        ++i;
+                        while (json[i] == ' ') ++i;
+                        ++i; // :
+                        while (json[i] == ' ') ++i;
+                        if (json[i] != '\"') {
+                            while (json[i] != ',' && json[i] != '}') {
+                                str_val += json[i++];
+                            }
+                            has_key = false;
+                        } else {
+                            in_token = true;
+                            continue;
+                        }
+                    } else {
+                        has_key = false;
+                    }
+
+                    ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
+                    ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
+                    ::replace(str_key, "\\\"",    "\""); // \\\"   -> "
+
+                    try {
+                        result[str_key] = std::stoi(str_val);
+                    } catch (...) {
+                        //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
+
+                    }
+                    str_key = "";
+                    str_val = "";
+                    in_token = false;
+                    continue;
+                }
+                if (has_key == false) {
+                    str_key += json[i];
+                } else {
+                    str_val += json[i];
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+std::string gpt_random_prompt(std::mt19937 & rng) {
+    const int r = rng() % 10;
+    switch (r) {
+        case 0: return "So";
+        case 1: return "Once upon a time";
+        case 2: return "When";
+        case 3: return "The";
+        case 4: return "After";
+        case 5: return "If";
+        case 6: return "import";
+        case 7: return "He";
+        case 8: return "She";
+        case 9: return "They";
+        default: return "To";
+    }
+
+    return "The";
+}
+
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<gpt_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.size() == 0) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
+bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
+    printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
+
+    vocab.token_to_id = ::json_parse(fname);
+
+    for (const auto & kv : vocab.token_to_id) {
+        vocab.id_to_token[kv.second] = kv.first;
+    }
+
+    printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
+
+    // print the vocabulary
+    //for (auto kv : vocab.token_to_id) {
+    //    printf("'%s' -> %d\n", kv.first.data(), kv.second);
+    //}
+
+    return true;
+}
+
+gpt_vocab::id gpt_sample_top_k_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        int    top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng) {
+    int n_logits = vocab.id_to_token.size();
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    for (int i = 0; i < n_logits; i++) {
+        logits_id.push_back(std::make_pair(logits[i], i));
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    // normalize
+    {
+        double sum = 0.0f;
+        for (int i = 0; i < (int)logits_id.size(); i++) {
+            sum += logits_id[i].first;
+        }
+
+        sum = 1.0/sum;
+        for (int i = 0; i < (int)logits_id.size(); i++) {
+            logits_id[i].first *= sum;
+        }
+    }
+
+    if (top_p < 1.0f) {
+        {
+            double cumsum = 0.0f;
+            for (int i = 0; i < top_k; i++) {
+                cumsum += logits_id[i].first;
+                if (cumsum >= top_p) {
+                    logits_id.resize(i+1);
+                    break;
+                }
+            }
+        }
+
+        // normalize again
+        {
+            double sum = 0.0f;
+            for (int i = 0; i < (int)logits_id.size(); i++) {
+                sum += logits_id[i].first;
+            }
+
+            sum = 1.0/sum;
+            for (int i = 0; i < (int)logits_id.size(); i++) {
+                logits_id[i].first *= sum;
+            }
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int)logits_id.size(); i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
+    //}
+    //exit(0);
+
+    // sample from the obtained distribution
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    for (int i = 0; i < (int) logits_id.size(); i++) {
+        probs.push_back(logits_id[i].first);
+    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
diff --git a/examples/utils.h b/examples/utils.h
new file mode 100644 (file)
index 0000000..aee9abf
--- /dev/null
@@ -0,0 +1,84 @@
+// Various helper functions and utilities
+
+#pragma once
+
+#include <string>
+#include <map>
+#include <vector>
+#include <random>
+#include <thread>
+
+//
+// CLI argument parsing
+//
+
+struct gpt_params {
+    int32_t seed      = -1; // RNG seed
+    int32_t n_threads = std::min(8, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_predict = 200; // new tokens to predict
+
+    // sampling parameters
+    int32_t top_k = 40;
+    float   top_p = 0.9f;
+    float   temp  = 1.0f;
+
+    int32_t n_batch = 8; // batch size for prompt processing
+
+    std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
+    std::string prompt;
+};
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
+
+std::string gpt_random_prompt(std::mt19937 & rng);
+
+//
+// Vocab utils
+//
+
+struct gpt_vocab {
+    using id    = int32_t;
+    using token = std::string;
+
+    std::map<token, id> token_to_id;
+    std::map<id, token> id_to_token;
+};
+
+void replace(std::string & str, const std::string & needle, const std::string & replacement);
+
+// poor-man's JSON parsing
+std::map<std::string, int32_t> json_parse(const std::string & fname);
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
+
+// load the tokens from encoder.json
+bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
+
+// sample next token given probabilities for each embedding
+//
+//   - consider only the top K tokens
+//   - from them, consider only the top tokens with cumulative probability > P
+//
+// TODO: not sure if this implementation is correct
+// TODO: temperature is not implemented
+//
+gpt_vocab::id gpt_sample_top_k_top_p(
+        const gpt_vocab & vocab,
+        const float * logits,
+        int    top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng);
+
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
new file mode 100644 (file)
index 0000000..04837fb
--- /dev/null
@@ -0,0 +1,511 @@
+#pragma once
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdbool.h>
+
+#define GGML_MAX_DIMS     4
+#define GGML_MAX_NODES    4096
+#define GGML_MAX_PARAMS   16
+#define GGML_MAX_CONTEXTS 16
+
+#ifdef __ARM_NEON
+// we use the built-in 16-bit float type
+typedef __fp16 ggml_fp16_t;
+#else
+typedef uint16_t ggml_fp16_t;
+#endif
+
+float ggml_fp16_to_fp32(ggml_fp16_t x);
+ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+struct ggml_object;
+struct ggml_context;
+
+enum ggml_type {
+    GGML_TYPE_I8,
+    GGML_TYPE_I16,
+    GGML_TYPE_I32,
+    GGML_TYPE_F16,
+    GGML_TYPE_F32,
+    GGML_TYPE_COUNT,
+};
+
+enum ggml_op {
+    GGML_OP_NONE = 0,
+
+    GGML_OP_DUP,
+    GGML_OP_ADD,
+    GGML_OP_SUB,
+    GGML_OP_MUL,
+    GGML_OP_DIV,
+    GGML_OP_SQR,
+    GGML_OP_SQRT,
+    GGML_OP_SUM,
+    GGML_OP_MEAN,
+    GGML_OP_REPEAT,
+    GGML_OP_ABS,
+    GGML_OP_SGN,
+    GGML_OP_NEG,
+    GGML_OP_STEP,
+    GGML_OP_RELU,
+    GGML_OP_GELU,
+    GGML_OP_NORM, // normalize
+
+    GGML_OP_MUL_MAT,
+
+    GGML_OP_SCALE,
+    GGML_OP_CPY,
+    GGML_OP_RESHAPE,
+    GGML_OP_VIEW,
+    GGML_OP_PERMUTE,
+    GGML_OP_TRANSPOSE,
+    GGML_OP_GET_ROWS,
+    GGML_OP_DIAG_MASK_INF,
+    GGML_OP_SOFT_MAX,
+    GGML_OP_ROPE,
+
+    GGML_OP_COUNT,
+};
+
+// n-dimensional tensor
+struct ggml_tensor {
+    enum ggml_type type;
+
+    int    n_dims;
+    int    ne[GGML_MAX_DIMS]; // number of elements
+    size_t nb[GGML_MAX_DIMS]; // stride in bytes:
+                              // nb[0] = sizeof(type)
+                              // nb[1] = nb[0]   * ne[0] + padding
+                              // nb[i] = nb[i-1] * ne[i-1]
+
+    // compute data
+    enum ggml_op op;
+
+    bool is_param;
+
+    struct ggml_tensor * grad;
+    struct ggml_tensor * src0;
+    struct ggml_tensor * src1;
+
+    // thread scheduling
+    int n_tasks;
+
+    // performance
+    int     perf_runs;
+    int64_t perf_cycles;
+    int64_t perf_time_us;
+
+    void * data;
+    char pad[8];
+};
+
+// computation graph
+struct ggml_cgraph {
+    int n_nodes;
+    int n_leafs;
+    int n_threads;
+
+    size_t work_size;
+    struct ggml_tensor * work;
+
+    struct ggml_tensor * nodes[GGML_MAX_NODES];
+    struct ggml_tensor * grads[GGML_MAX_NODES];
+    struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+    // performance
+    int     perf_runs;
+    int64_t perf_cycles;
+    int64_t perf_time_us;
+};
+
+struct ggml_init_params {
+    // memory pool
+    size_t mem_size;   // bytes
+    void * mem_buffer; // if NULL, memory will be allocated internally
+};
+
+int64_t ggml_time_ms(void);
+int64_t ggml_time_us(void);
+int64_t ggml_cycles(void);
+int64_t ggml_cycles_per_ms(void);
+
+void ggml_print_object (const struct ggml_object * obj);
+void ggml_print_objects(const struct ggml_context * ctx);
+
+int    ggml_nelements(const struct ggml_tensor * tensor);
+size_t ggml_nbytes   (const struct ggml_tensor * tensor);
+
+size_t ggml_type_size   (enum ggml_type type);
+size_t ggml_element_size(const struct ggml_tensor * tensor);
+
+struct ggml_context * ggml_init(struct ggml_init_params params);
+void ggml_free(struct ggml_context * ctx);
+
+size_t ggml_used_mem(const struct ggml_context * ctx);
+
+struct ggml_tensor * ggml_new_tensor(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int *ne);
+
+struct ggml_tensor * ggml_new_tensor_1d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0);
+
+struct ggml_tensor * ggml_new_tensor_2d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1);
+
+struct ggml_tensor * ggml_new_tensor_3d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2);
+
+struct ggml_tensor * ggml_new_tensor_4d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2,
+        int    ne3);
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+ void * ggml_get_data    (const struct ggml_tensor * tensor);
+float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+//
+// operations on tensors with backpropagation
+//
+
+struct ggml_tensor * ggml_dup(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_sub(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_mul(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_div(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_sqrt(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// return scalar
+// TODO: compute sum along rows
+struct ggml_tensor * ggml_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// mean along rows
+struct ggml_tensor * ggml_mean(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// if a is the same shape as b, and a is not parameter, return a
+// otherwise, return a new tensor: repeat(a) to fit in b
+struct ggml_tensor * ggml_repeat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+struct ggml_tensor * ggml_abs(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_sgn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_neg(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_step(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// TODO: double-check this computation is correct
+struct ggml_tensor * ggml_gelu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// normalize along rows
+// TODO: eps is hardcoded to 1e-5 for now
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// A: m rows, n columns
+// B: p rows, n columns (i.e. we transpose it internally)
+// result is m columns, p rows
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+//
+// operations on tensors without backpropagation
+//
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_scale(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// a -> b, return view(b)
+struct ggml_tensor * ggml_cpy(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// return view(a), b specifies the new shape
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2);
+
+// offset in bytes
+struct ggml_tensor * ggml_view_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        size_t                offset);
+
+struct ggml_tensor * ggml_view_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        size_t                nb1, // row stride in bytes
+        size_t                offset);
+
+struct ggml_tensor * ggml_permute(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   axis0,
+        int                   axis1,
+        int                   axis2,
+        int                   axis3);
+
+// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+struct ggml_tensor * ggml_transpose(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_get_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
+// set elements above the diagonal to -INF
+// in-place, returns view(a)
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past);
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+// rotary position embedding
+// in-place, returns view(a)
+// if mode == 1, skip n_past elements
+// TODO: avoid creating a new tensor every time
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode);
+
+//
+// automatic differentiation
+//
+
+void ggml_set_param(
+        struct ggml_context * ctx,
+        struct ggml_tensor * tensor);
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+
+struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+
+// print info and performance information for the graph
+void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+// dump the graph into a file using the dot format
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+//
+// optimization
+//
+
+// optimization methods
+enum ggml_opt_type {
+    GGML_OPT_ADAM,
+    GGML_OPT_LBFGS,
+};
+
+// linesearch methods
+enum ggml_linesearch {
+    GGML_LINESEARCH_DEFAULT = 1,
+
+    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
+    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
+    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+};
+
+// optimization return values
+enum ggml_opt_result {
+    GGML_OPT_OK = 0,
+    GGML_OPT_DID_NOT_CONVERGE,
+    GGML_OPT_NO_CONTEXT,
+    GGML_OPT_INVALID_WOLFE,
+    GGML_OPT_FAIL,
+
+    GGML_LINESEARCH_FAIL = -128,
+    GGML_LINESEARCH_MINIMUM_STEP,
+    GGML_LINESEARCH_MAXIMUM_STEP,
+    GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+    GGML_LINESEARCH_INVALID_PARAMETERS,
+};
+
+// optimization parameters
+//
+//   see ggml.c (ggml_opt_default_params) for default values
+//
+struct ggml_opt_params {
+    enum ggml_opt_type type;
+
+    int n_threads;
+
+    // delta-based convergence test
+    //
+    //   if past == 0 - disabled
+    //   if past > 0:
+    //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+    //
+    int past;
+    float delta;
+
+    // maximum number of iterations without improvement
+    //
+    //   if 0 - disabled
+    //   if > 0:
+    //     assume convergence if no cost improvement in this number of iterations
+    //
+    int max_no_improvement;
+
+    bool print_forward_graph;
+    bool print_backward_graph;
+
+    union {
+        // ADAM parameters
+        struct {
+            int n_iter;
+
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float eps_f; // epsilon for convergence test
+            float eps_g; // epsilon for convergence test
+        } adam;
+
+        // LBFGS parameters
+        struct {
+            int m; // number of corrections to approximate the inv. Hessian
+            int n_iter;
+            int max_linesearch;
+
+            float eps;      // convergence tolerance
+            float ftol;     // line search tolerance
+            float wolfe;
+            float min_step;
+            float max_step;
+
+            enum ggml_linesearch linesearch;
+        } lbfgs;
+    };
+};
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+// optimize the function defined by the tensor f
+enum ggml_opt_result ggml_opt(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644 (file)
index 0000000..c8a6396
--- /dev/null
@@ -0,0 +1,85 @@
+if (GGML_ALL_WARNINGS)
+    if (CMAKE_COMPILER_IS_GNUCC OR CMAKE_C_COMPILER_ID MATCHES "Clang")
+        #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} \
+            -Wall                           \
+            -Wextra                         \
+            -Wpedantic                      \
+            -Wshadow                        \
+            -Wcast-qual                     \
+            -Wstrict-prototypes             \
+            -Wpointer-arith                 \
+        ")
+    else()
+        # todo : windows
+    endif()
+endif()
+
+# compiler flags
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror=vla")
+#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
+
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+
+if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
+    message(STATUS "ARM detected")
+    #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=apple-m1")
+else()
+    message(STATUS "x86 detected")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
+endif()
+
+
+# ggml
+
+set(TARGET ggml)
+
+# on APPLE - include Accelerate framework
+#if (APPLE)
+#    find_library(ACCELERATE_FRAMEWORK Accelerate)
+#    if (ACCELERATE_FRAMEWORK)
+#        message(STATUS "Accelerate framework found")
+#
+#        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
+#        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
+#    else()
+#        message(WARNING "Accelerate framework not found")
+#    endif()
+#endif()
+
+add_library(${TARGET}
+    ggml.c
+    )
+
+target_include_directories(${TARGET} PUBLIC
+    .
+    ../include
+    )
+
+target_link_libraries(${TARGET} PUBLIC m ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
+
+if (BUILD_SHARED_LIBS)
+    target_link_libraries(${TARGET} PUBLIC
+        ${CMAKE_DL_LIBS}
+        )
+
+    target_compile_definitions(${TARGET} PUBLIC
+        GGML_SHARED
+        )
+endif()
+
+target_compile_definitions(${TARGET} PUBLIC
+    ${GGML_EXTRA_FLAGS}
+    )
+
+if (MINGW)
+    target_link_libraries(${TARGET} PUBLIC
+        stdc++
+        )
+endif()
+
+install(TARGETS ${TARGET}
+    LIBRARY DESTINATION lib
+    ARCHIVE DESTINATION lib/static
+    )
diff --git a/src/ggml.c b/src/ggml.c
new file mode 100644 (file)
index 0000000..bef3dc5
--- /dev/null
@@ -0,0 +1,5614 @@
+#include "ggml/ggml.h"
+
+#include <assert.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdatomic.h>
+
+#include <pthread.h>
+
+#define GGML_DEBUG 0
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+#define GGML_MEM_ALIGN 16
+
+#define UNUSED(x) (void)(x)
+#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+    return x;
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+    return x;
+}
+
+#else
+
+#include <immintrin.h>
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32 = { w };
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+       union {
+               float as_value;
+               uint32_t as_bits;
+       } fp32 = { f };
+       return fp32.as_bits;
+}
+
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+#endif
+
+//
+// timing
+//
+
+// TODO: need to be able to disable these in performance critical code since they make slow system calls
+int64_t ggml_time_ms(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+
+int64_t ggml_cycles(void) {
+    return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+    return CLOCKS_PER_SEC/1000;
+}
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+       const size_t CACHE_LINE_SIZE = hardware_destructive_interference_size;
+#else
+       const size_t CACHE_LINE_SIZE = 64;
+#endif
+
+const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i]  = v; }
+
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i]  = v; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+}
+
+inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += x[i]*y[i];
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    ggml_float sumf = 0.0;
+#ifdef __ARM_NEON
+    const int n64 = 64*(n/64);
+
+    float16x8_t sum0 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum1 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum2 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum3 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum4 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum5 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum6 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+    float16x8_t sum7 = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
+
+    float16x8_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n64; i += 64) {
+        x0 = vld1q_f16(x + i + 0 );
+        x1 = vld1q_f16(x + i + 8 );
+        x2 = vld1q_f16(x + i + 16);
+        x3 = vld1q_f16(x + i + 24);
+        x4 = vld1q_f16(x + i + 32);
+        x5 = vld1q_f16(x + i + 40);
+        x6 = vld1q_f16(x + i + 48);
+        x7 = vld1q_f16(x + i + 56);
+
+        y0 = vld1q_f16(y + i + 0 );
+        y1 = vld1q_f16(y + i + 8 );
+        y2 = vld1q_f16(y + i + 16);
+        y3 = vld1q_f16(y + i + 24);
+        y4 = vld1q_f16(y + i + 32);
+        y5 = vld1q_f16(y + i + 40);
+        y6 = vld1q_f16(y + i + 48);
+        y7 = vld1q_f16(y + i + 56);
+
+        sum0 = vfmaq_f16(sum0, x0, y0);
+        sum1 = vfmaq_f16(sum1, x1, y1);
+        sum2 = vfmaq_f16(sum2, x2, y2);
+        sum3 = vfmaq_f16(sum3, x3, y3);
+        sum4 = vfmaq_f16(sum4, x4, y4);
+        sum5 = vfmaq_f16(sum5, x5, y5);
+        sum6 = vfmaq_f16(sum6, x6, y6);
+        sum7 = vfmaq_f16(sum7, x7, y7);
+    }
+
+    // TODO: F16 - better way to reduce this ?
+    float16x8_t sum = vaddq_f16(sum0, sum1);
+
+    sum = vaddq_f16(sum, sum2);
+    sum = vaddq_f16(sum, sum3);
+    sum = vaddq_f16(sum, sum4);
+    sum = vaddq_f16(sum, sum5);
+    sum = vaddq_f16(sum, sum6);
+    sum = vaddq_f16(sum, sum7);
+
+    sumf += sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7];
+
+    // I think this somehow makes the inference worse .. not sure ?
+    //sum0 = vaddq_f16(sum0, sum1);
+    //sum2 = vaddq_f16(sum2, sum3);
+    //sum4 = vaddq_f16(sum4, sum5);
+    //sum6 = vaddq_f16(sum6, sum7);
+
+    //sum0 = vaddq_f16(sum0, sum2);
+    //sum4 = vaddq_f16(sum4, sum6);
+
+    //sum0 = vaddq_f16(sum0, sum4);
+
+    //for (int i = 0; i < 8; ++i) {
+    //    sumf += sum0[i];
+    //}
+
+    // leftovers
+    for (int i = n64; i < n; ++i) {
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#else
+    // AVX 256-bit (unroll 4)
+    const int n32 = 32*(n/32);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        sum0 = _mm256_fmadd_ps(x0, y0, sum0);
+        sum1 = _mm256_fmadd_ps(x1, y1, sum1);
+        sum2 = _mm256_fmadd_ps(x2, y2, sum2);
+        sum3 = _mm256_fmadd_ps(x3, y3, sum3);
+    }
+
+    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
+#ifdef __ARM_NEON
+    // NEON 128-bit
+    const int n64 = 64*(n/64);
+
+    const float16x8_t v8 = vdupq_n_f16(v);
+
+    float16x8_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n64; i += 64) {
+        y0 = vld1q_f16(y + i + 0 );
+        y1 = vld1q_f16(y + i + 8 );
+        y2 = vld1q_f16(y + i + 16);
+        y3 = vld1q_f16(y + i + 24);
+        y4 = vld1q_f16(y + i + 32);
+        y5 = vld1q_f16(y + i + 40);
+        y6 = vld1q_f16(y + i + 48);
+        y7 = vld1q_f16(y + i + 56);
+
+        x0 = vld1q_f16(x + i + 0 );
+        x1 = vld1q_f16(x + i + 8 );
+        x2 = vld1q_f16(x + i + 16);
+        x3 = vld1q_f16(x + i + 24);
+        x4 = vld1q_f16(x + i + 32);
+        x5 = vld1q_f16(x + i + 40);
+        x6 = vld1q_f16(x + i + 48);
+        x7 = vld1q_f16(x + i + 56);
+
+        y0 = vfmaq_f16(y0, x0, v8);
+        y1 = vfmaq_f16(y1, x1, v8);
+        y2 = vfmaq_f16(y2, x2, v8);
+        y3 = vfmaq_f16(y3, x3, v8);
+        y4 = vfmaq_f16(y4, x4, v8);
+        y5 = vfmaq_f16(y5, x5, v8);
+        y6 = vfmaq_f16(y6, x6, v8);
+        y7 = vfmaq_f16(y7, x7, v8);
+
+        vst1q_f16(y + i + 0 , y0);
+        vst1q_f16(y + i + 8 , y1);
+        vst1q_f16(y + i + 16, y2);
+        vst1q_f16(y + i + 24, y3);
+        vst1q_f16(y + i + 32, y4);
+        vst1q_f16(y + i + 40, y5);
+        vst1q_f16(y + i + 48, y6);
+        vst1q_f16(y + i + 56, y7);
+    }
+
+    // leftovers
+    for (int i = n64; i < n; ++i) {
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#else
+    // AVX 256-bit
+    const int n32 = 32*(n/32);
+
+    const __m256 v8 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_fmadd_ps(x0, v8, y0);
+        y1 = _mm256_fmadd_ps(x1, v8, y1);
+        y2 = _mm256_fmadd_ps(x2, v8, y2);
+        y3 = _mm256_fmadd_ps(x3, v8, y3);
+
+        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+    }
+
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#endif
+}
+
+
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+
+const ggml_float GELU_COEF_A    = 0.044715;
+const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+
+inline static void ggml_vec_gelu_f32 (const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        //y[i] = 0.5f*x[i]*(1.f + tanhf(SQRT_2_OVER_PI*(x[i] + 0.044715f*x[i]*x[i]*x[i])));
+        //0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
+        const ggml_float xx = x[i];
+        y[i] = 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx)));
+    }
+}
+
+inline static void ggml_vec_sum_f32     (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+//
+// data types
+//
+
+const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+    sizeof(int8_t ),
+    sizeof(int16_t),
+    sizeof(int32_t),
+    sizeof(ggml_fp16_t),
+    sizeof(float  ),
+};
+
+const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+    "NONE",
+
+    "DUP",
+    "ADD",
+    "SUB",
+    "MUL",
+    "DIV",
+    "SQR",
+    "SQRT",
+    "SUM",
+    "MEAN",
+    "REPEAT",
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "RELU",
+    "GELU",
+    "NORM",
+
+    "MUL_MAT",
+
+    "SCALE",
+    "CPY",
+    "RESHAPE",
+    "VIEW",
+    "PERMUTE",
+    "TRANSPOSE",
+    "GET_ROWS",
+    "DIAG_MASK_INF",
+    "SOFT_MAX",
+    "ROPE",
+};
+
+const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+    "none",
+
+    "x",
+    "x+y",
+    "x-y",
+    "x*y",
+    "x/y",
+    "x^2",
+    "√x",
+    "Σx",
+    "Σx/n",
+    "repeat(x)",
+    "abs(x)",
+    "sgn(x)",
+    "-x",
+    "step(x)",
+    "relu(x)",
+    "gelu(x)",
+    "norm(x)",
+
+    "X*Y",
+
+    "x*v",
+    "x-\\>y",
+    "reshape(x)",
+    "view(x)",
+    "permute(x)",
+    "transpose(x)",
+    "get_rows(x)",
+    "diag_mask_inf(x)",
+    "soft_max(x)",
+    "rope(x)",
+};
+
+//
+// ggml object
+//
+
+struct ggml_object {
+    size_t offset;
+    size_t size;
+
+    struct ggml_object * next;
+
+    char padding[8];
+};
+
+const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+//
+// ggml context
+//
+
+struct ggml_context {
+    size_t mem_size;
+    void * mem_buffer;
+    bool   mem_buffer_owned;
+
+    int n_objects;
+
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
+};
+
+struct ggml_context_container {
+    bool used;
+
+    struct ggml_context context;
+};
+
+//
+// compute types
+//
+
+enum ggml_task_type {
+    GGML_TASK_INIT = 0,
+    GGML_TASK_COMPUTE,
+    GGML_TASK_FINALIZE,
+};
+
+struct ggml_compute_params {
+    enum ggml_task_type type;
+
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+    struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+};
+
+// global state
+struct ggml_state g_state;
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+    GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
+            obj->offset, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+    while (obj != NULL) {
+        ggml_print_object(obj);
+        obj = obj->next;
+    }
+
+    GGML_PRINT("%s: --- end ---\n", __func__);
+}
+
+int ggml_nelements(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+int ggml_nrows(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
+}
+
+size_t ggml_type_size(enum ggml_type type) {
+    return GGML_TYPE_SIZE[type];
+}
+
+size_t ggml_element_size(const struct ggml_tensor * tensor) {
+    return GGML_TYPE_SIZE[tensor->type];
+}
+
+bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_vector(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[0]  == t1->ne[0])  &&
+        (t0->ne[2]  == t1->ne[2])  &&
+        (t0->ne[3]  == t1->ne[3]);
+}
+
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];;
+}
+
+bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[0] == t1->ne[0] ) &&
+        (t0->ne[1] == t1->ne[1] ) &&
+        (t0->ne[2] == t1->ne[2] ) &&
+        (t0->ne[3] == t1->ne[3] );
+}
+
+// check if t1 can be represented as a repeatition of t0
+bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t1->ne[0]%t0->ne[0] == 0) &&
+        (t1->ne[1]%t0->ne[1] == 0) &&
+        (t1->ne[2]%t0->ne[2] == 0) &&
+        (t1->ne[3]%t0->ne[3] == 0);
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define ggml_assert_aligned(ptr) \
+    assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+    // find non-used context in g_state
+    struct ggml_context * ctx = NULL;
+
+    static bool first_time = true;
+    if (first_time) {
+        for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+            g_state.contexts[i].used = false;
+        }
+        first_time = false;
+    }
+
+    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+        if (!g_state.contexts[i].used) {
+            g_state.contexts[i].used = true;
+            ctx = &g_state.contexts[i].context;
+
+            GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
+            break;
+        }
+    }
+
+    if (ctx == NULL) {
+        GGML_PRINT_DEBUG("%s\n", "ggml_init: no unused context found");
+        return NULL;
+    }
+
+    *ctx = (struct ggml_context) {
+        .mem_size         = params.mem_size,
+        .mem_buffer       = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        .mem_buffer_owned = params.mem_buffer ? false : true,
+        .n_objects        = 0,
+        .objects_begin    = NULL,
+        .objects_end      = NULL,
+    };
+
+    ggml_assert_aligned(ctx->mem_buffer);
+
+    return ctx;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+        if (&g_state.contexts[i].context == ctx) {
+            g_state.contexts[i].used = false;
+
+            GGML_PRINT_DEBUG("ggml_free: context %d with %d objects has been freed. memory used = %zu\n",
+                    i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
+
+            if (ctx->mem_buffer_owned) {
+                free(ctx->mem_buffer);
+            }
+
+            return;
+        }
+    }
+
+    GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+    return ctx->objects_end->offset + ctx->objects_end->size;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_tensor * ggml_new_tensor_impl(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int* ne,
+        void*  data) {
+    // always insert objects at the end of the context's memory pool
+    struct ggml_object * obj_cur = ctx->objects_end;
+
+    const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
+    const size_t cur_size   = obj_cur == NULL ? 0 : obj_cur->size;
+    const size_t cur_end    = cur_offset + cur_size;
+
+    size_t size_needed = 0;
+
+    if (data == NULL) {
+        size_needed += GGML_TYPE_SIZE[type];
+        for (int i = 0; i < n_dims; i++) {
+            size_needed *= ne[i];
+        }
+        // align to GGML_MEM_ALIGN
+        size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
+
+    }
+    size_needed += sizeof(struct ggml_tensor);
+
+    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+        GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
+        assert(false);
+        return NULL;
+    }
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+    *obj_new = (struct ggml_object) {
+        .offset = cur_end + GGML_OBJECT_SIZE,
+        .size   = size_needed,
+        .next   = NULL,
+    };
+
+    if (obj_cur != NULL) {
+        obj_cur->next = obj_new;
+    } else {
+        // this is the first object in this context
+        ctx->objects_begin = obj_new;
+    }
+
+    ctx->objects_end = obj_new;
+
+    //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
+
+    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
+
+    ggml_assert_aligned(result);
+
+    *result = (struct ggml_tensor) {
+        /*.type         =*/ type,
+        /*.n_dims       =*/ n_dims,
+        /*.ne           =*/ { 1, 1, 1, 1 },
+        /*.nb           =*/ { 0, 0, 0, 0 },
+        /*.op           =*/ GGML_OP_NONE,
+        /*.is_param     =*/ false,
+        /*.grad         =*/ NULL,
+        /*.src0         =*/ NULL,
+        /*.src1         =*/ NULL,
+        /*.n_tasks      =*/ 0,
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+        /*.data         =*/ data == NULL ? (void *)(result + 1) : data,
+        /*.pad          =*/ { 0 },
+    };
+
+    ggml_assert_aligned(result->data);
+
+    for (int i = 0; i < n_dims; i++) {
+        result->ne[i] = ne[i];
+    }
+
+    result->nb[0] = GGML_TYPE_SIZE[type];
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+    }
+
+    ctx->n_objects++;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int* ne) {
+    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0) {
+    return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1) {
+    const int ne[2] = { ne0, ne1 };
+    return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2) {
+    const int ne[3] = { ne0, ne1, ne2 };
+    return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    ne0,
+        int    ne1,
+        int    ne2,
+        int    ne3) {
+    const int ne[4] = { ne0, ne1, ne2, ne3 };
+    return ggml_new_tensor(ctx, type, 4, ne);
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+    ggml_set_f32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL);
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    memset(tensor->data, 0, ggml_nbytes(tensor));
+    return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(false); // TODO: implement
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+
+    return tensor;
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(false); // TODO: implement
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+
+    assert(false);
+    return 0.0f;
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(false); // TODO: implement
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+    return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+    assert(tensor->type == GGML_TYPE_F32);
+    return (float *)(tensor->data);
+}
+
+struct ggml_tensor * ggml_view_tensor(
+        struct ggml_context * ctx,
+        const struct ggml_tensor * src) {
+    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_dup
+
+struct ggml_tensor * ggml_dup_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_DUP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_dup(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+struct ggml_tensor * ggml_add_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    assert(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ADD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_sub
+
+struct ggml_tensor * ggml_sub_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    assert(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SUB;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sub(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+struct ggml_tensor * ggml_mul_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    assert(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    if (inplace) {
+        assert(is_node == false);
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MUL;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_mul(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+struct ggml_tensor * ggml_div_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    assert(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    if (inplace) {
+        assert(is_node == false);
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_DIV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_div(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+struct ggml_tensor * ggml_sqr_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SQR;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+struct ggml_tensor * ggml_sqrt_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SQRT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_SUM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement
+        is_node = true;
+    }
+
+    int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
+
+    result->op   = GGML_OP_MEAN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    assert(ggml_can_repeat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ABS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_abs(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_abs_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_abs_impl(ctx, a, true);
+}
+
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SGN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sgn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sgn_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sgn_impl(ctx, a, true);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_NEG;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_neg(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_neg_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_neg_impl(ctx, a, true);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_STEP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_step(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_step_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_step_impl(ctx, a, true);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_relu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_relu_impl(ctx, a, true);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_impl(ctx, a, true);
+}
+
+// ggml_norm
+
+struct ggml_tensor * ggml_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_norm_impl(ctx, a, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    assert(ggml_can_mul_mat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_MUL_MAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_scale
+
+struct ggml_tensor * ggml_scale_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool inplace) {
+    assert(ggml_is_scalar(b));
+    assert(ggml_is_padded_1d(a));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_SCALE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_scale(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_scale_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_scale_impl(ctx, a, b, true);
+}
+
+// ggml_cpy
+
+struct ggml_tensor * ggml_cpy_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool inplace) {
+    assert(ggml_nelements(a) == ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // make a view of the destination
+    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+
+    result->op   = GGML_OP_CPY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_cpy_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, true);
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    assert(ggml_is_contiguous(a));
+    assert(ggml_is_contiguous(b));
+    assert(ggml_nelements(a) == ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1) {
+    assert(ggml_is_contiguous(a));
+    assert(ggml_nelements(a) == ne0*ne1);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[2] = { ne0, ne1 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2) {
+    assert(ggml_is_contiguous(a));
+    assert(ggml_nelements(a) == ne0*ne1*ne2);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int ne[3] = { ne0, ne1, ne2 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        size_t                offset) {
+    if (a->grad) {
+        assert(false); // gradient propagation is not supported
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        size_t                nb1,
+        size_t                offset) {
+    if (a->grad) {
+        assert(false); // gradient propagation is not supported
+    }
+
+    const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = result->nb[1]*ne1;
+    result->nb[3] = result->nb[2];
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   axis0,
+        int                   axis1,
+        int                   axis2,
+        int                   axis3) {
+    assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+    assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+    assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+    assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+    assert(axis0 != axis1);
+    assert(axis0 != axis2);
+    assert(axis0 != axis3);
+    assert(axis1 != axis2);
+    assert(axis1 != axis3);
+    assert(axis2 != axis3);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    int ne[GGML_MAX_DIMS];
+    int nb[GGML_MAX_DIMS];
+
+    ne[axis0] = a->ne[0];
+    ne[axis1] = a->ne[1];
+    ne[axis2] = a->ne[2];
+    ne[axis3] = a->ne[3];
+
+    nb[axis0] = a->nb[0];
+    nb[axis1] = a->nb[1];
+    nb[axis2] = a->nb[2];
+    nb[axis3] = a->nb[3];
+
+    result->ne[0] = ne[0];
+    result->ne[1] = ne[1];
+    result->ne[2] = ne[2];
+    result->ne[3] = ne[3];
+
+    result->nb[0] = nb[0];
+    result->nb[1] = nb[1];
+    result->nb[2] = nb[2];
+    result->nb[3] = nb[3];
+
+    result->op   = GGML_OP_PERMUTE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the permutation here?
+
+    return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->ne[0] = a->ne[1];
+    result->ne[1] = a->ne[0];
+
+    result->nb[0] = a->nb[1];
+    result->nb[1] = a->nb[0];
+
+    result->op   = GGML_OP_TRANSPOSE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: implement non F32 return
+    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+
+    result->op   = GGML_OP_GET_ROWS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_diag_mask_inf
+
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    ((int32_t *) b->data)[0] = n_past;
+
+    result->op   = GGML_OP_DIAG_MASK_INF;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_soft_max
+
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_rope
+
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
+    assert(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        assert(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_dims;
+    ((int32_t *) b->data)[2] = mode;
+
+    result->op   = GGML_OP_ROPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+        struct ggml_context * ctx,
+        struct ggml_tensor * tensor) {
+    tensor->is_param = true;
+
+    assert(tensor->grad == NULL);
+    tensor->grad = ggml_dup_tensor(ctx, tensor);
+}
+
+// ggml_compute_forward_dup
+
+void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_nelements(dst) == ggml_nelements(src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    if (src0->nb[0] == sizeof(float)) {
+        const int ne00 = src0->ne[0];
+        const int ne01 = src0->ne[1];
+        const int ne02 = src0->ne[2];
+        const int ne03 = src0->ne[3];
+
+        const size_t nb00 = src0->nb[0];
+        const size_t nb01 = src0->nb[1];
+        const size_t nb02 = src0->nb[2];
+        const size_t nb03 = src0->nb[3];
+
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            const size_t rs = ne00*nb00;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        char * dst_ptr = (char *) dst->data + id*rs;
+
+                        memcpy(dst_ptr, src0_ptr, rs);
+
+                        id++;
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = ggml_fp32_to_fp16(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            assert(false); // TODO: implement
+        }
+    } else {
+        GGML_PRINT_DEBUG("ggml_compute_forward_dup: fix me\n"); // TODO !!!
+        const int ne00 = src0->ne[0];
+        const int ne01 = src0->ne[1];
+        const int ne02 = src0->ne[2];
+        const int ne03 = src0->ne[3];
+
+        const size_t nb00 = src0->nb[0];
+        const size_t nb01 = src0->nb[1];
+        const size_t nb02 = src0->nb[2];
+        const size_t nb03 = src0->nb[3];
+
+        int id = 0;
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                              char *  dst_ptr = (char *)  dst->data + id*sizeof(float);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        id++;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_add
+
+void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_add_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+void ggml_compute_forward_add(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sub
+
+void ggml_compute_forward_sub_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sub_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+void ggml_compute_forward_sub(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sub_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mul
+
+void ggml_compute_forward_mul_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_mul_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+void ggml_compute_forward_mul(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_div
+
+void ggml_compute_forward_div_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_div_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+void ggml_compute_forward_div(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_div_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sqr
+
+void ggml_compute_forward_sqr_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n     = ggml_nrows(src0);
+    const int nc    = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqr_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_sqr(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqr_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sqrt
+
+void ggml_compute_forward_sqrt_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqrt_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_sqrt(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqrt_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sum
+
+void ggml_compute_forward_sum_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_scalar(dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+    assert(src0->nb[0] == sizeof(float));
+
+    *(float *) (dst->data) = 0.0f;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) (dst->data),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_sum(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mean
+
+void ggml_compute_forward_mean_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    assert(ne0 == 1);
+    assert(ne1 == ne01);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    UNUSED(ne0);
+    UNUSED(ne1);
+    UNUSED(ne2);
+    UNUSED(ne3);
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = 0; i01 < ne01; i01++) {
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f;
+
+                ggml_vec_sum_f32(ne00,
+                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_mean(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mean_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_repeat
+
+void ggml_compute_forward_repeat_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: implement support for rank > 2 tensors
+    assert(src0->ne[2] == 1);
+    assert(src0->ne[3] == 1);
+    assert( dst->ne[2] == 1);
+    assert( dst->ne[3] == 1);
+
+    const int nc  = dst->ne[0];
+    const int nr  = dst->ne[1];
+    const int nc0 = src0->ne[0];
+    const int nr0 = src0->ne[1];
+    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+
+    // TODO: support for transposed / permuted tensors
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    // TODO: maybe this is not optimal?
+    for (int i = 0; i < nrr; i++) {
+        for (int j = 0; j < ncr; j++) {
+            for (int k = 0; k < nr0; k++) {
+                ggml_vec_cpy_f32(nc0,
+                        (float *) ((char *)  dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
+                        (float *) ((char *) src0->data + (        k)*(src0->nb[1])));
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_repeat(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_abs
+
+void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_abs(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_abs_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_sgn
+
+void ggml_compute_forward_sgn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sgn_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_sgn(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sgn_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_neg
+
+void ggml_compute_forward_neg_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_neg_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_neg(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_neg_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_step
+
+void ggml_compute_forward_step_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_step_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_step(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_step_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_relu
+
+void ggml_compute_forward_relu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+void ggml_compute_forward_relu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_relu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_gelu
+
+void ggml_compute_forward_gelu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_gelu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data  + i*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_gelu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_norm
+
+void ggml_compute_forward_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = 0; i01 < ne01; i01++) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    ggml_float v = x[i00] - mean;
+                    y[i00] = v;
+                    sum2 += v*v;
+                }
+
+                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_mul_mat
+
+void ggml_compute_forward_mul_mat_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    assert(ne02 == ne12);
+    assert(ne03 == ne13);
+    assert(ne2  == ne12);
+    assert(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    assert(nb0 == sizeof(float));
+    assert(nb0 <= nb1);
+    assert(nb1 <= nb2);
+    assert(nb2 <= nb3);
+
+    assert(ne0 == ne01);
+    assert(ne1 == ne11);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+    if (params->type == GGML_TASK_INIT) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+        float * const wdata = params->wdata;
+
+        ggml_vec_cpy_f32(ne, dst->data, wdata);
+
+        for (int k = 1; k < nth; k++) {
+            ggml_vec_acc_f32(ne, dst->data, wdata + (ne + CACHE_LINE_SIZE_F32)*k);
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // TODO: do not support transposed src1
+        assert(nb10 == sizeof(float));
+
+        // parallelize by src0 rows using ggml_vec_dot_f32
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                // src1 indices
+                const int i13 = i03;
+                const int i12 = i02;
+                const int i11 = ic;
+
+                // dst indices
+                const int i0 = i01;
+                const int i1 = i11;
+                const int i2 = i02;
+                const int i3 = i03;
+
+                ggml_vec_dot_f32(ne00,
+                        (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                        (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                        (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+            }
+        }
+    } else {
+        // parallelize by src1 columns using ggml_vec_mad_f32
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        float * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        // dst indices
+                        const int i1 = i11;
+                        const int i2 = i12;
+                        const int i3 = i13;
+
+                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        ggml_vec_mad_f32(ne01,
+                                (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
+                                (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
+                               *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+void ggml_compute_forward_mul_mat_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    assert(ne02 == ne12);
+    assert(ne03 == ne13);
+    assert(ne2  == ne12);
+    assert(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    assert(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    assert(nb0 == sizeof(float));
+    assert(nb0 <= nb1);
+    assert(nb1 <= nb2);
+    assert(nb2 <= nb3);
+
+    assert(ne0 == ne01);
+    assert(ne1 == ne11);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    //
+    // nb00 <  nb01 - src0 is transposed
+    //   compute by src0 columns
+
+    if (params->type == GGML_TASK_INIT) {
+        if (nb01 >= nb00) {
+            ggml_fp16_t * const wdata = params->wdata;
+
+            for (int i = 0; i < ne10*ne11*ne12*ne13; ++i) {
+                wdata[i] = ggml_fp32_to_fp16(((float *) src1->data)[i]);
+            }
+
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (nb01 >= nb00) {
+            return;
+        }
+
+        // TODO: fix this memset (wsize is overestimated)
+        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+        ggml_fp16_t * const wdata = params->wdata;
+
+        for (int i = 0; i < ne; ++i) {
+            ((float *) dst->data)[i] = ggml_fp16_to_fp32(wdata[i]);
+        }
+
+        for (int k = 1; k < nth; k++) {
+            for (int i = 0; i < ne; ++i) {
+                ((float *) dst->data)[i] += ggml_fp16_to_fp32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+            }
+        }
+
+        return;
+    }
+
+    if (nb01 >= nb00) {
+        // fp16 -> half the size, so divide by 2
+        const int nb10 = src1->nb[0]/2; UNUSED(nb10);
+
+        // TODO: do not support transposed src1
+        assert(nb10 == sizeof(ggml_fp16_t));
+
+        // parallelize by src0 rows using ggml_vec_dot_f32
+
+        // total rows in src0
+        const int nr = ne01*ne02*ne03;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        ggml_fp16_t * wdata = params->wdata;
+
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+
+            for (int ic = 0; ic < ne11; ++ic) {
+                // src1 indices
+                const int i13 = i03;
+                const int i12 = i02;
+                const int i11 = ic;
+
+                // dst indices
+                const int i0 = i01;
+                const int i1 = i11;
+                const int i2 = i02;
+                const int i3 = i03;
+
+                assert(ne00 % 64 == 0);
+
+                ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + i11)*ne00;
+
+                float * dst_row = (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3));
+
+                ggml_vec_dot_f16(ne00, dst_row, src0_row, src1_col);
+            }
+        }
+    } else {
+        // parallelize by src1 columns using ggml_vec_mad_f32
+        // each thread has its own work data
+        // during FINALIZE we accumulate all work data into dst
+
+        const int nb10 = src1->nb[0];
+        const int nb11 = src1->nb[1];
+        const int nb12 = src1->nb[2];
+        const int nb13 = src1->nb[3];
+
+        // total columns in src1
+        const int nc = ne10;
+
+        // columns per thread
+        const int dc = (nc + nth - 1)/nth;
+
+        // column range for this thread
+        const int ic0 = dc*ith;
+        const int ic1 = MIN(ic0 + dc, nc);
+
+        // work data for thread
+        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+        ggml_fp16_t * const wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    // dst indices
+                    const int i1 = i11;
+                    const int i2 = i12;
+                    const int i3 = i13;
+
+                    ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+                    for (int ic = ic0; ic < ic1; ++ic) {
+                        // src1 indices
+                        const int i10 = ic;
+
+                        // src0 indices
+                        const int i03 = i13;
+                        const int i02 = i12;
+                        const int i00 = ic;
+
+                        assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+                        ggml_fp16_t * src0_col =  (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+                        float         src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+                        ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
+                    }
+                }
+            }
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+void ggml_compute_forward_mul_mat(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+// ggml_compute_forward_scale
+
+void ggml_compute_forward_scale_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    const float v = *(float *) src1->data;
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i*(dst->nb[1])), v);
+    }
+}
+
+void ggml_compute_forward_scale(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_scale_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cpy
+
+void ggml_compute_forward_cpy(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
+// ggml_compute_forward_reshape
+
+void ggml_compute_forward_reshape(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+void ggml_compute_forward_view(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_permute
+
+void ggml_compute_forward_permute(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_transpose
+
+void ggml_compute_forward_transpose(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0) {
+    // NOP
+    UNUSED(params);
+    UNUSED(src0);
+}
+
+// ggml_compute_forward_get_rows
+
+void ggml_compute_forward_get_rows_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
+            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = ggml_fp16_to_fp32(v);
+        }
+    }
+}
+
+void ggml_compute_forward_get_rows_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i*dst->nb[1]),
+                (float *) ((char *) src0->data + r*src0->nb[1]));
+    }
+}
+
+void ggml_compute_forward_get_rows(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+void ggml_compute_forward_diag_mask_inf_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 1);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int k = 0; k < nz; k++) {
+        for (int j = 0; j < nr; j++) {
+            for (int i = n_past; i < nc; i++) {
+                if (i > n_past + j) {
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_diag_mask_inf(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_soft_max
+
+void ggml_compute_forward_soft_max_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int k = 0; k < nz; k++) {
+        for (int j = 0; j < nr; j++) {
+            float *p = (float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1]);
+
+#ifndef NDEBUG
+            for (int i = 0; i < nc; ++i) {
+                assert(!isnan(p[i]));
+            }
+#endif
+
+            float max = -INFINITY;
+            for (int i = 0; i < nc; i++) {
+                max = MAX(max, p[i]);
+            }
+
+            ggml_float sum = 0.0;
+            for (int i = 0; i < nc; i++) {
+                const ggml_float v = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+                sum += v;
+                p[i] = v;
+            }
+
+            assert(sum > 0.0f);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(nc, p, sum);
+
+#ifndef NDEBUG
+            for (int i = 0; i < nc; ++i) {
+                assert(!isnan(p[i]));
+                assert(!isinf(p[i]));
+            }
+#endif
+        }
+    }
+}
+
+void ggml_compute_forward_soft_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_rope
+
+void ggml_compute_forward_rope_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    //const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    const int nb3 = src0->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(float));
+
+    // TODO: optimize
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = (mode == 0 ? n_past + i2 : i2);
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < n_dims; i0 += 2) {
+                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+                    const double cos_theta = cos(p*theta);
+                    const double sin_theta = sin(p*theta);
+
+                    const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                          float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    double x0 = src[0];
+                    double x1 = src[1];
+
+                    dst_data[0] = x0*cos_theta - x1*sin_theta;
+                    dst_data[1] = x0*sin_theta + x1*cos_theta;
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_rope(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+/////////////////////////////////
+
+void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    assert(params);
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                ggml_compute_forward_dup(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ADD:
+            {
+                ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SUB:
+            {
+                ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_MUL:
+            {
+                ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_DIV:
+            {
+                ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SQR:
+            {
+                ggml_compute_forward_sqr(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SQRT:
+            {
+                ggml_compute_forward_sqrt(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SUM:
+            {
+                ggml_compute_forward_sum(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MEAN:
+            {
+                ggml_compute_forward_mean(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_compute_forward_repeat(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_STEP:
+            {
+                ggml_compute_forward_step(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_NORM:
+            {
+                ggml_compute_forward_norm(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SCALE:
+            {
+                ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_compute_forward_cpy(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                ggml_compute_forward_reshape(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_VIEW:
+            {
+                ggml_compute_forward_view(params, tensor->src0);
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                ggml_compute_forward_permute(params, tensor->src0);
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                ggml_compute_forward_transpose(params, tensor->src0);
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                ggml_compute_forward_soft_max(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ROPE:
+            {
+                ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                assert(false);
+            } break;
+    };
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+    struct ggml_tensor * src0 = tensor->src0;
+    struct ggml_tensor * src1 = tensor->src1;
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_ADD:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_SUB:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_MUL:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx, src1, tensor->grad),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx, src0, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_DIV:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_div(ctx, tensor->grad, src1),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_sub_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_div(ctx, tensor, src1)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQR:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_mul(ctx, src0, tensor->grad),
+                                    ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQRT:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_div(ctx,
+                                    ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+                                    tensor),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SUM:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_repeat(ctx, tensor->grad, src0->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_MEAN:
+            {
+                assert(false); // TODO: implement
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_sum(ctx, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_ABS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_sgn(ctx, src0),
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SGN:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_NEG:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_STEP:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_RELU:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_sub_impl(ctx,
+                            src0->grad,
+                            ggml_mul(ctx,
+                                ggml_step(ctx, src0),
+                                tensor->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_GELU:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_NORM:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                if (src0->grad) {
+                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
+                    assert(false);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                // TODO: fix transpose, the node will break the graph connections
+                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SCALE:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CPY:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_VIEW:
+            {
+                assert(false); // not supported
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_ROPE:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                assert(false);
+            } break;
+    };
+}
+
+void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+    if (node->grad == NULL) {
+        // this usually happens when we generate intermediate nodes from constants in the backward pass
+        // it can also happen during forward pass, if the user performs computations with constants
+        if (node->op != GGML_OP_NONE) {
+            //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
+        }
+    }
+
+    // check if already visited
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i] == node) {
+            return;
+        }
+    }
+
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        if (cgraph->leafs[i] == node) {
+            return;
+        }
+    }
+
+    if (node->src0) {
+        ggml_visit_parents(cgraph, node->src0);
+    }
+
+    if (node->src1) {
+        ggml_visit_parents(cgraph, node->src1);
+    }
+
+    if (node->op == GGML_OP_NONE && node->grad == NULL) {
+        // reached a leaf node, not part of the gradient graph (e.g. a constant)
+        assert(cgraph->n_leafs < GGML_MAX_NODES);
+
+        cgraph->leafs[cgraph->n_leafs] = node;
+        cgraph->n_leafs++;
+    } else {
+        assert(cgraph->n_nodes < GGML_MAX_NODES);
+
+        cgraph->nodes[cgraph->n_nodes] = node;
+        cgraph->grads[cgraph->n_nodes] = node->grad;
+        cgraph->n_nodes++;
+    }
+}
+
+void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+    if (!expand) {
+        cgraph->n_nodes = 0;
+        cgraph->n_leafs = 0;
+    }
+
+    const int n0 = cgraph->n_nodes;
+    UNUSED(n0);
+
+    ggml_visit_parents(cgraph, tensor);
+
+    const int n_new = cgraph->n_nodes - n0;
+    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+    if (n_new > 0) {
+        // the last added node should always be starting point
+        assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+    }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+    ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
+    struct ggml_cgraph result = {
+        /*.n_nodes      =*/ 0,
+        /*.n_leafs      =*/ 0,
+        /*.n_threads    =*/ 0,
+        /*.work_size    =*/ 0,
+        /*.work         =*/ NULL,
+        /*.nodes        =*/ { NULL },
+        /*.grads        =*/ { NULL },
+        /*.leafs        =*/ { NULL },
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+    };
+
+    ggml_build_forward_impl(&result, tensor, false);
+
+    return result;
+}
+
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+    struct ggml_cgraph result = *gf;
+
+    assert(gf->n_nodes > 0);
+
+    // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
+    if (keep) {
+        for (int i = 0; i < gf->n_nodes; i++) {
+            struct ggml_tensor * node = gf->nodes[i];
+
+            if (node->grad) {
+                node->grad = ggml_dup_tensor(ctx, node);
+                gf->grads[i] = node->grad;
+            }
+        }
+    }
+
+    for (int i = gf->n_nodes - 1; i >= 0; i--) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        // because we detached the grad nodes from the original graph, we can afford inplace operations
+        if (node->grad) {
+            ggml_compute_backward(ctx, node, keep);
+        }
+    }
+
+    for (int i = gf->n_nodes - 1; i >= 0; i--) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->is_param) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            ggml_build_forward_impl(&result, node->grad, true);
+        }
+    }
+
+    return result;
+}
+
+//
+// thread data
+//
+// synchronization is done via busy loops
+// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
+//
+
+#ifdef __APPLE__
+
+//#include <os/lock.h>
+
+//typedef os_unfair_lock ggml_lock_t;
+//
+//#define ggml_lock_init(x)    UNUSED(x)
+//#define ggml_lock_destroy(x) UNUSED(x)
+//#define ggml_lock_lock       os_unfair_lock_lock
+//#define ggml_lock_unlock     os_unfair_lock_unlock
+//
+//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x)    UNUSED(x)
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#else
+
+//typedef pthread_spinlock_t ggml_lock_t;
+
+//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
+//#define ggml_lock_destroy pthread_spin_destroy
+//#define ggml_lock_lock    pthread_spin_lock
+//#define ggml_lock_unlock  pthread_spin_unlock
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x)    UNUSED(x)
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#endif
+
+struct ggml_compute_state_shared {
+    ggml_lock_t spin;
+
+    int n_threads;
+
+    // synchronization primitives
+    atomic_int  n_ready;
+    atomic_bool has_work;
+    atomic_bool stop; // stop all threads
+};
+
+struct ggml_compute_state {
+    pthread_t thrd;
+
+    struct ggml_compute_params params;
+    struct ggml_tensor * node;
+
+    struct ggml_compute_state_shared * shared;
+};
+
+// function used by each compute thread
+void * ggml_graph_compute_one(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+    ggml_compute_forward(&state->params, state->node);
+
+    return NULL;
+}
+
+void * ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+    const int n_threads = state->shared->n_threads;
+
+    while (true) {
+        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
+            atomic_store(&state->shared->has_work, false);
+        } else {
+            while (atomic_load(&state->shared->has_work)) {
+                if (atomic_load(&state->shared->stop)) {
+                    return NULL;
+                }
+                ggml_lock_lock  (&state->shared->spin);
+                ggml_lock_unlock(&state->shared->spin);
+            }
+        }
+
+        atomic_fetch_sub(&state->shared->n_ready, 1);
+
+        // wait for work
+        while (!atomic_load(&state->shared->has_work)) {
+            if (atomic_load(&state->shared->stop)) {
+                return NULL;
+            }
+            ggml_lock_lock  (&state->shared->spin);
+            ggml_lock_unlock(&state->shared->spin);
+        }
+
+        // check if we should stop
+        if (atomic_load(&state->shared->stop)) {
+            break;
+        }
+
+        if (state->node) {
+            ggml_compute_forward(&state->params, state->node);
+            state->node = NULL;
+        } else {
+            break;
+        }
+    }
+
+    return NULL;
+}
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    if (cgraph->n_threads <= 0) {
+        cgraph->n_threads = 8;
+    }
+
+    const int n_threads = cgraph->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.spin      =*/ GGML_LOCK_INITIALIZER,
+        /*.n_threads =*/ n_threads,
+        /*.n_ready   =*/ 0,
+        /*.has_work  =*/ false,
+        /*.stop      =*/ false,
+    };
+    struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
+
+    // create thread pool
+    if (n_threads > 1) {
+        ggml_lock_init(&state_shared.spin);
+
+        atomic_store(&state_shared.has_work, true);
+
+        for (int j = 0; j < n_threads - 1; j++) {
+            workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .params = {
+                    .type  = GGML_TASK_COMPUTE,
+                    .ith   = j + 1,
+                    .nth   = n_threads,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                },
+                .node   = NULL,
+                .shared = &state_shared,
+            };
+            int rc = pthread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            assert(rc == 0);
+            UNUSED(rc);
+        }
+    }
+
+    // initialize tasks + work buffer
+    {
+        size_t work_size = 0;
+
+        // thread scheduling for the different operations
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            struct ggml_tensor * node = cgraph->nodes[i];
+
+            switch (node->op) {
+                case GGML_OP_DUP:
+                case GGML_OP_ADD:
+                case GGML_OP_SUB:
+                case GGML_OP_MUL:
+                case GGML_OP_DIV:
+                case GGML_OP_SQR:
+                case GGML_OP_SQRT:
+                case GGML_OP_SUM:
+                case GGML_OP_MEAN:
+                case GGML_OP_REPEAT:
+                case GGML_OP_ABS:
+                case GGML_OP_SGN:
+                case GGML_OP_NEG:
+                case GGML_OP_STEP:
+                case GGML_OP_RELU:
+                case GGML_OP_GELU:
+                case GGML_OP_NORM:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_MUL_MAT:
+                    {
+                        // TODO: use different scheduling for different matrix sizes
+                        node->n_tasks = n_threads;
+
+                        // TODO: better way to determine if the matrix is transposed
+                        if (node->src0->nb[1] < node->src0->nb[0]) {
+                            size_t cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            work_size = MAX(work_size, cur);
+                        } else {
+                            if (node->src0->type == GGML_TYPE_F16 &&
+                                node->src1->type == GGML_TYPE_F32) {
+                                size_t cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+                                work_size = MAX(work_size, cur);
+                            }
+                        }
+                    } break;
+                case GGML_OP_SCALE:
+                case GGML_OP_CPY:
+                case GGML_OP_RESHAPE:
+                case GGML_OP_VIEW:
+                case GGML_OP_PERMUTE:
+                case GGML_OP_TRANSPOSE:
+                case GGML_OP_GET_ROWS:
+                case GGML_OP_DIAG_MASK_INF:
+                case GGML_OP_SOFT_MAX:
+                case GGML_OP_ROPE:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_NONE:
+                    {
+                        node->n_tasks = 1;
+                    } break;
+                case GGML_OP_COUNT:
+                    {
+                        assert(false);
+                    } break;
+            };
+        }
+
+        if (cgraph->work != NULL && work_size > cgraph->work_size) {
+            assert(false); // TODO: better handling
+        }
+
+        if (work_size > 0 && cgraph->work == NULL) {
+            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+
+            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
+            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+        }
+    }
+
+    const int64_t perf_start_cycles  = ggml_cycles();
+    const int64_t perf_start_time_us = ggml_time_us();
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
+
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
+        //if (node->grad == NULL && node->perf_runs > 0) {
+        //    continue;
+        //}
+
+        const int64_t perf_node_start_cycles  = ggml_cycles();
+        const int64_t perf_node_start_time_us = ggml_time_us();
+
+        // INIT
+        struct ggml_compute_params params = {
+            /*.type  =*/ GGML_TASK_INIT,
+            /*.ith   =*/ 0,
+            /*.nth   =*/ n_threads,
+            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+        };
+
+        ggml_compute_forward(&params, node);
+
+        // COMPUTE
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            // launch thread pool
+            for (int j = 0; j < n_threads - 1; j++) {
+                workers[j].params = (struct ggml_compute_params) {
+                    .type  = GGML_TASK_COMPUTE,
+                    .ith   = j + 1,
+                    .nth   = n_threads,
+                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                    .wdata = cgraph->work ? cgraph->work->data : NULL,
+                };
+                workers[j].node = node;
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) > 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_store(&state_shared.has_work, true);
+        }
+
+        params.type = GGML_TASK_COMPUTE;
+        ggml_compute_forward(&params, node);
+
+        if (node->n_tasks > 1) {
+            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+                atomic_store(&state_shared.has_work, false);
+            }
+
+            while (atomic_load(&state_shared.has_work)) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+
+            atomic_fetch_sub(&state_shared.n_ready, 1);
+
+            while (atomic_load(&state_shared.n_ready) != 0) {
+                ggml_lock_lock  (&state_shared.spin);
+                ggml_lock_unlock(&state_shared.spin);
+            }
+        }
+
+        // FINALIZE
+        params.type = GGML_TASK_FINALIZE;
+        ggml_compute_forward(&params, node);
+
+        // performance stats (node)
+        {
+            int64_t perf_cycles_cur  = ggml_cycles()  - perf_node_start_cycles;
+            int64_t perf_time_us_cur = ggml_time_us() - perf_node_start_time_us;
+
+            node->perf_runs++;
+            node->perf_cycles  += perf_cycles_cur;
+            node->perf_time_us += perf_time_us_cur;
+        }
+    }
+
+    // join thread pool
+    if (n_threads > 1) {
+        atomic_store(&state_shared.stop, true);
+        atomic_store(&state_shared.has_work, true);
+
+        for (int j = 0; j < n_threads - 1; j++) {
+            int rc = pthread_join(workers[j].thrd, NULL);
+            assert(rc == 0);
+            UNUSED(rc);
+        }
+
+        ggml_lock_destroy(&state_shared.spin);
+    }
+
+    // performance stats (graph)
+    {
+        int64_t perf_cycles_cur  = ggml_cycles()  - perf_start_cycles;
+        int64_t perf_time_us_cur = ggml_time_us() - perf_start_time_us;
+
+        cgraph->perf_runs++;
+        cgraph->perf_cycles  += perf_cycles_cur;
+        cgraph->perf_time_us += perf_time_us_cur;
+
+        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
+                __func__, cgraph->perf_runs,
+                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
+                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
+                (double) perf_time_us_cur     / 1000.0,
+                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+    }
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * grad = cgraph->grads[i];
+
+        if (grad) {
+            ggml_set_zero(grad);
+        }
+    }
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+    int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
+
+    GGML_PRINT("=== GRAPH ===\n");
+
+    GGML_PRINT_DEBUG("n_threads       = %d\n",       cgraph->n_threads);
+    GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+
+    GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        perf_total_per_op_us[node->op] += node->perf_time_us;
+
+        GGML_PRINT(" - %3d: [ %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+                i,
+                node->ne[0], node->ne[1],
+                GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
+                (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
+                (double) node->perf_time_us / 1000.0,
+                (double) node->perf_time_us / 1000.0 / node->perf_runs);
+    }
+
+    GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        struct ggml_tensor * node = cgraph->leafs[i];
+
+        GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+                i,
+                node->ne[0], node->ne[1],
+                GGML_OP_LABEL[node->op]);
+    }
+
+    for (int i = 0; i < GGML_OP_COUNT; i++) {
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+    }
+
+    GGML_PRINT("========================================\n");
+}
+
+// check if node is part of the graph
+bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    if (cgraph == NULL) {
+        return true;
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i] == node) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * parent = cgraph->nodes[i];
+
+        if (parent->grad == node) {
+            return parent;
+        }
+    }
+
+    return NULL;
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+    char color[16];
+
+    FILE * fp = fopen(filename, "w");
+    assert(fp);
+
+    fprintf(fp, "digraph G {\n");
+    fprintf(fp, "  newrank = true;\n");
+    fprintf(fp, "  rankdir = LR;\n");
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+
+        if (ggml_graph_get_parent(gb, node) != NULL) {
+            continue;
+        }
+
+        if (node->is_param) {
+            snprintf(color, sizeof(color), "yellow");
+        } else if (node->grad) {
+            if (ggml_graph_find(gf, node)) {
+                snprintf(color, sizeof(color), "green");
+            } else {
+                snprintf(color, sizeof(color), "lightblue");
+            }
+        } else {
+            snprintf(color, sizeof(color), "white");
+        }
+
+        fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"%d [%d, %d] | <x>%s",
+                (void *) node, color,
+                i, node->ne[0], node->ne[1],
+                GGML_OP_SYMBOL[node->op]);
+
+        if (node->grad) {
+            fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+        } else {
+            fprintf(fp, "\"; ]\n");
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        snprintf(color, sizeof(color), "pink");
+
+        if (ggml_nelements(node) == 1) {
+            fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>%.1e\"; ]\n",
+                    (void *) node, color, ggml_get_f32_1d(node, 0));
+        } else {
+            fprintf(fp, "  \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>CONST %d [%d, %d]\"; ]\n",
+                    (void *) node, color,
+                    i, node->ne[0], node->ne[1]);
+        }
+    }
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+
+        struct ggml_tensor * parent = ggml_graph_get_parent(gb, node);
+
+        if (node->src0) {
+            struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0);
+
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n",
+                    parent0 ? (void *) parent0 : (void *) node->src0,
+                    parent0 ? "g" : "x",
+                    parent ? (void *) parent : (void *) node,
+                    parent ? "g" : "x",
+                    parent ? "empty" : "vee",
+                    parent ? "dashed" : "solid");
+        }
+
+        if (node->src1) {
+            struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1);
+
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n",
+                    parent1 ? (void *) parent1 : (void *) node->src1,
+                    parent1 ? "g" : "x",
+                    parent ? (void *) parent : (void *) node,
+                    parent ? "g" : "x",
+                    parent ? "empty" : "vee",
+                    parent ? "dashed" : "solid");
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        if (node->src0) {
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n",
+                    (void *) node->src0, "x",
+                    (void *) node, "x");
+        }
+
+        if (node->src1) {
+            fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n",
+                    (void *) node->src1, "x",
+                    (void *) node, "x");
+        }
+    }
+
+    fprintf(fp, "}\n");
+
+    fclose(fp);
+
+    GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to set tensor from array
+        for (int j = 0; j < ne; ++j) {
+            ggml_set_f32_1d(ps[p], j, x[i++]);
+        }
+    }
+}
+
+void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int j = 0; j < ne; ++j) {
+            x[i++] = ggml_get_f32_1d(ps[p], j);
+        }
+    }
+}
+
+void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+    int i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int j = 0; j < ne; ++j) {
+            g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
+        }
+    }
+}
+
+//
+// ADAM
+//
+//   ref: https://arxiv.org/pdf/1412.6980.pdf
+//
+
+enum ggml_opt_result ggml_opt_adam(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+    assert(ggml_is_scalar(f));
+
+    gf->n_threads = params.n_threads;
+    gb->n_threads = params.n_threads;
+
+    // these will store the parameters we want to optimize
+    struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+    int np = 0;
+    int nx = 0;
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        if (gf->nodes[i]->is_param) {
+            GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+            assert(np < GGML_MAX_PARAMS);
+
+            ps[np++] = gf->nodes[i];
+            nx += ggml_nelements(gf->nodes[i]);
+        }
+    }
+
+    // constants
+    const float alpha = params.adam.alpha;
+    const float beta1 = params.adam.beta1;
+    const float beta2 = params.adam.beta2;
+    const float eps   = params.adam.eps;
+
+    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
+    float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
+    float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
+    float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
+    float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
+    float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
+    float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+
+    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+    // initialize
+    ggml_vec_set_f32(nx, m, 0.0f);
+    ggml_vec_set_f32(nx, v, 0.0f);
+
+    // update view
+    ggml_opt_get_params(np, ps, x);
+
+    // compute the function value
+    ggml_graph_reset  (gf);
+    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_compute(ctx, gb);
+
+    float fx_prev = ggml_get_f32_1d(f, 0);
+    if (pf) {
+        pf[0] = fx_prev;
+    }
+
+    int n_no_improvement = 0;
+    float fx_best = fx_prev;
+
+    // run the optimizer
+    for (int t = 0; t < params.adam.n_iter; ++t) {
+        GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
+
+        GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
+        GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0));
+        GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0));
+
+        for (int i = 0; i < np; ++i) {
+            GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
+                    ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0));
+        }
+
+        const int64_t t_start_wall = ggml_time_us();
+        const int64_t t_start_cpu = ggml_cycles();
+        UNUSED(t_start_wall);
+        UNUSED(t_start_cpu);
+
+        {
+            // update the gradient
+            ggml_opt_get_grad(np, ps, g1);
+
+            // m_t = beta1*m_t-1 + (1 - beta1)*g_t
+            ggml_vec_scale_f32(nx, m, beta1);
+            ggml_vec_mad_f32  (nx, m, g1, 1.0f - beta1);
+
+            // g2 = g1^2
+            ggml_vec_sqr_f32  (nx, g2, g1);
+
+            // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
+            ggml_vec_scale_f32(nx, v, beta2);
+            ggml_vec_mad_f32  (nx, v, g2, 1.0f - beta2);
+
+            // m^hat = m_t / (1 - beta1^t)
+            // v^hat = v_t / (1 - beta2^t)
+            // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+            ggml_vec_cpy_f32  (nx, mh, m);
+            ggml_vec_cpy_f32  (nx, vh, v);
+
+            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
+            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, t + 1)));
+
+            ggml_vec_sqrt_f32 (nx, vh, vh);
+            ggml_vec_acc1_f32 (nx, vh, eps);
+
+            ggml_vec_div_f32  (nx, mh, mh, vh);
+            ggml_vec_sub_f32  (nx, x,  x,  mh);
+
+            // update the parameters
+            ggml_opt_set_params(np, ps, x);
+        }
+
+        ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(ctx, gb);
+
+        const float fx = ggml_get_f32_1d(f, 0);
+
+        // check convergence
+        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+            GGML_PRINT_DEBUG("converged\n");
+
+            return GGML_OPT_OK;
+        }
+
+        // delta-based convergence test
+        if (pf != NULL) {
+            // need at least params.past iterations to start checking for convergence
+            if (params.past <= t) {
+                const float rate = (pf[t%params.past] - fx)/fx;
+
+                if (fabs(rate) < params.delta) {
+                    return GGML_OPT_OK;
+                }
+            }
+
+            pf[t%params.past] = fx;
+        }
+
+        // check for improvement
+        if (params.max_no_improvement > 0) {
+            if (fx_best > fx) {
+                fx_best = fx;
+                n_no_improvement = 0;
+            } else {
+                ++n_no_improvement;
+
+                if (n_no_improvement >= params.max_no_improvement) {
+                    return GGML_OPT_OK;
+                }
+            }
+        }
+
+        fx_prev = fx;
+
+        {
+            const int64_t t_end_cpu = ggml_cycles();
+            GGML_PRINT_DEBUG("time iter:      %5.3f s\n", (t_end_cpu - t_start_cpu)/CLOCKS_PER_SEC);
+            UNUSED(t_end_cpu);
+
+            const int64_t t_end_wall = ggml_time_us();
+            GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
+            UNUSED(t_end_wall);
+        }
+    }
+
+    return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+//
+// L-BFGS
+//
+// the L-BFGS implementation below is based on the following implementation:
+//
+//   https://github.com/chokkan/liblbfgs
+//
+
+struct ggml_lbfgs_iteration_data {
+    float alpha;
+    float ys;
+    float * s;
+    float * y;
+};
+
+static enum ggml_opt_result linesearch_backtracking(
+        struct ggml_context * ctx,
+        const struct ggml_opt_params * params,
+        int nx,
+        float * x,
+        float * fx,
+        float * g,
+        float * d,
+        float * step,
+        const float * xp,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb,
+        const int np,
+        struct ggml_tensor * ps[]) {
+    int count = 0;
+
+    float width  = 0.0f;
+    float dg     = 0.0f;
+    float finit  = 0.0f;
+    float dginit = 0.0f;
+    float dgtest = 0.0f;
+
+    const float dec = 0.5f;
+    const float inc = 2.1f;
+
+    if (*step <= 0.) {
+        return GGML_LINESEARCH_INVALID_PARAMETERS;
+    }
+
+    // compute the initial gradient in the search direction
+    ggml_vec_dot_f32(nx, &dginit, g, d);
+
+    // make sure that d points to a descent direction
+    if (0 < dginit) {
+        return GGML_LINESEARCH_FAIL;
+    }
+
+    // initialize local variables
+    finit = *fx;
+    dgtest = params->lbfgs.ftol*dginit;
+
+    while (true) {
+        ggml_vec_cpy_f32(nx, x, xp);
+        ggml_vec_mad_f32(nx, x, d, *step);
+
+        // evaluate the function and gradient values
+        {
+            ggml_opt_set_params(np, ps, x);
+
+            ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(ctx, gb);
+
+            ggml_opt_get_grad(np, ps, g);
+
+            *fx = ggml_get_f32_1d(f, 0);
+        }
+
+        ++count;
+
+        if (*fx > finit + (*step)*dgtest) {
+            width = dec;
+        } else {
+            // Armijo condition is satisfied
+            if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) {
+                return count;
+            }
+
+            ggml_vec_dot_f32(nx, &dg, g, d);
+
+            // check the Wolfe condition
+            if (dg < params->lbfgs.wolfe * dginit) {
+                width = inc;
+            } else {
+                if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) {
+                    // regular Wolfe conditions
+                    return count;
+                }
+
+                if(dg > -params->lbfgs.wolfe*dginit) {
+                    width = dec;
+                } else {
+                    // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
+                    return count;
+                }
+                return count;
+            }
+        }
+
+        if (*step < params->lbfgs.min_step) {
+            return GGML_LINESEARCH_MINIMUM_STEP;
+        }
+        if (*step > params->lbfgs.max_step) {
+            return GGML_LINESEARCH_MAXIMUM_STEP;
+        }
+        if (params->lbfgs.max_linesearch <= count) {
+            return GGML_LINESEARCH_MAXIMUM_ITERATIONS;
+        }
+
+        (*step) *= width;
+    }
+
+    return GGML_LINESEARCH_FAIL;
+}
+
+enum ggml_opt_result ggml_opt_lbfgs(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+    if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
+        params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+            return GGML_OPT_INVALID_WOLFE;
+        }
+    }
+
+    gf->n_threads = params.n_threads;
+    gb->n_threads = params.n_threads;
+
+    const int m = params.lbfgs.m;
+
+    // these will store the parameters we want to optimize
+    struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+    int np = 0;
+    int nx = 0;
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        if (gf->nodes[i]->is_param) {
+            GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+            assert(np < GGML_MAX_PARAMS);
+
+            ps[np++] = gf->nodes[i];
+            nx += ggml_nelements(gf->nodes[i]);
+        }
+    }
+
+    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
+    float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
+    float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
+    float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
+    float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+
+    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+    float fx    = 0.0f; // cost function value
+    float xnorm = 0.0f; // ||x||
+    float gnorm = 0.0f; // ||g||
+    float step  = 0.0f;
+
+    // initialize x from the graph nodes
+    ggml_opt_get_params(np, ps, x);
+
+    // the L-BFGS memory
+    struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
+
+    for (int i = 0; i < m; ++i) {
+        lm[i].alpha = 0.0f;
+        lm[i].ys    = 0.0f;
+        lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+        lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+    }
+
+    // evaluate the function value and its gradient
+    {
+        ggml_opt_set_params(np, ps, x);
+
+        ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(ctx, gb);
+
+        ggml_opt_get_grad(np, ps, g);
+
+        fx = ggml_get_f32_1d(f, 0);
+    }
+
+    if (pf) {
+        pf[0] = fx;
+    }
+
+    float fx_best = fx;
+
+    // search direction = -gradient
+    ggml_vec_neg_f32(nx, d, g);
+
+    // ||x||, ||g||
+    ggml_vec_norm_f32(nx, &xnorm, x);
+    ggml_vec_norm_f32(nx, &gnorm, g);
+
+    if (xnorm < 1.0f) {
+        xnorm = 1.0f;
+    }
+
+    // already optimized
+    if (gnorm/xnorm <= params.lbfgs.eps) {
+        return GGML_OPT_OK;
+    }
+
+    // initial step
+    ggml_vec_norm_inv_f32(nx, &step, d);
+
+    int j                = 0;
+    int k                = 1;
+    int ls               = 0;
+    int end              = 0;
+    int bound            = 0;
+    int n_no_improvement = 0;
+
+    float ys   = 0.0f;
+    float yy   = 0.0f;
+    float beta = 0.0f;
+
+    while (true) {
+        // store the current position and gradient vectors
+        ggml_vec_cpy_f32(nx, xp, x);
+        ggml_vec_cpy_f32(nx, gp, g);
+
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+
+        if (ls < 0) {
+            // linesearch failed - go back to the previous point and return
+            ggml_vec_cpy_f32(nx, x, xp);
+            ggml_vec_cpy_f32(nx, g, gp);
+
+            return ls;
+        }
+
+        ggml_vec_norm_f32(nx, &xnorm, x);
+        ggml_vec_norm_f32(nx, &gnorm, g);
+
+        GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+
+        if (xnorm < 1.0) {
+            xnorm = 1.0;
+        }
+        if (gnorm/xnorm <= params.lbfgs.eps) {
+            // converged
+            return GGML_OPT_OK;
+        }
+
+        // delta-based convergence test
+        if (pf != NULL) {
+            // need at least params.past iterations to start checking for convergence
+            if (params.past <= k) {
+                const float rate = (pf[k%params.past] - fx)/fx;
+
+                if (fabs(rate) < params.delta) {
+                    return GGML_OPT_OK;
+                }
+            }
+
+            pf[k%params.past] = fx;
+        }
+
+        // check for improvement
+        if (params.max_no_improvement > 0) {
+            if (fx < fx_best) {
+                fx_best = fx;
+                n_no_improvement = 0;
+            } else {
+                n_no_improvement++;
+
+                if (n_no_improvement >= params.max_no_improvement) {
+                    return GGML_OPT_OK;
+                }
+            }
+        }
+
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+            // reached the maximum number of iterations
+            return GGML_OPT_DID_NOT_CONVERGE;
+        }
+
+        // update vectors s and y:
+        //   s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
+        //   y_{k+1} = g_{k+1} - g_{k}.
+        //
+        ggml_vec_sub_f32(nx, lm[end].s, x, xp);
+        ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+
+        // compute scalars ys and yy:
+        //     ys = y^t \cdot s    -> 1 / \rho.
+        //     yy = y^t \cdot y.
+        //
+        ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
+        ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+
+        lm[end].ys = ys;
+
+        // find new search direction
+        //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
+
+        bound = (m <= k) ? m : k;
+        k++;
+        end = (end + 1)%m;
+
+        // initialize search direction with -g
+        ggml_vec_neg_f32(nx, d, g);
+
+        j = end;
+        for (int i = 0; i < bound; ++i) {
+            j = (j + m - 1) % m;
+            // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+            ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
+            lm[j].alpha /= lm[j].ys;
+            // q_{i} = q_{i+1} - \alpha_{i} y_{i}
+            ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+        }
+
+        ggml_vec_scale_f32(nx, d, ys/yy);
+
+        for (int i = 0; i < bound; ++i) {
+            // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+            ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
+            beta /= lm[j].ys;
+            // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
+            ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
+            j = (j + 1)%m;
+        }
+
+        step = 1.0;
+    }
+
+    return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
+    struct ggml_opt_params result;
+
+    switch (type) {
+        case GGML_OPT_ADAM:
+            {
+                result = (struct ggml_opt_params) {
+                    .type      = GGML_OPT_ADAM,
+                    .n_threads = 1,
+                    .past      = 0,
+                    .delta     = 1e-5f,
+
+                    .max_no_improvement = 100,
+
+                    .print_forward_graph  = true,
+                    .print_backward_graph = true,
+
+                    .adam = {
+                        .n_iter = 10000,
+                        .alpha  = 0.001f,
+                        .beta1  = 0.9f,
+                        .beta2  = 0.999f,
+                        .eps    = 1e-8f,
+                        .eps_f  = 1e-5f,
+                        .eps_g  = 1e-3f,
+                    },
+                };
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                result = (struct ggml_opt_params) {
+                    .type      = GGML_OPT_LBFGS,
+                    .n_threads = 1,
+                    .past      = 0,
+                    .delta     = 1e-5f,
+
+                    .max_no_improvement = 0,
+
+                    .print_forward_graph  = true,
+                    .print_backward_graph = true,
+
+                    .lbfgs = {
+                        .m              = 6,
+                        .n_iter         = 100,
+                        .max_linesearch = 20,
+
+                        .eps      = 1e-5f,
+                        .ftol     = 1e-4f,
+                        .wolfe    = 0.9f,
+                        .min_step = 1e-20f,
+                        .max_step = 1e+20f,
+
+                        .linesearch = GGML_LINESEARCH_DEFAULT,
+                    },
+                };
+            } break;
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt(
+        struct ggml_context * ctx,
+        struct ggml_opt_params params,
+        struct ggml_tensor * f) {
+    bool free_ctx = false;
+    if (ctx == NULL) {
+        struct ggml_init_params params_ctx = {
+            .mem_size   = 16*1024*1024,
+            .mem_buffer = NULL,
+        };
+
+        ctx = ggml_init(params_ctx);
+        if (ctx == NULL) {
+            return GGML_OPT_NO_CONTEXT;
+        }
+
+        free_ctx = true;
+    }
+
+    enum ggml_opt_result result = GGML_OPT_OK;
+
+    // build forward + backward compute graphs
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
+
+    switch (params.type) {
+        case GGML_OPT_ADAM:
+            {
+                result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+            } break;
+    }
+
+    if (params.print_forward_graph) {
+        ggml_graph_print   (&gf);
+        ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
+    }
+
+    if (params.print_backward_graph) {
+        ggml_graph_print   (&gb);
+        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    }
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
new file mode 100644 (file)
index 0000000..5716e2d
--- /dev/null
@@ -0,0 +1,74 @@
+#
+# test-vec0
+
+set(TEST_TARGET test-vec0)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test-vec1 (x86)
+if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86")
+    set(TEST_TARGET test-vec1)
+    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+    target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+    set_target_properties(${TEST_TARGET} PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mf16c")
+endif()
+
+#
+# test-vec2 (arm)
+if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm")
+    set(TEST_TARGET test-vec2)
+    add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+    target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+endif()
+
+#
+# test-grad0
+
+set(TEST_TARGET test-grad0)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test-mul-mat
+
+set(TEST_TARGET test-mul-mat0)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test0
+
+set(TEST_TARGET test0)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test1
+
+set(TEST_TARGET test1)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test2
+
+set(TEST_TARGET test2)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
+#
+# test3
+
+set(TEST_TARGET test3)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
new file mode 100644 (file)
index 0000000..4814b59
--- /dev/null
@@ -0,0 +1,378 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+#define MAX_NARGS 2
+
+float frand() {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+int irand(int n) {
+    return rand()%n;
+}
+
+void get_random_dims(int * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+struct ggml_tensor * get_random_tensor(
+        struct ggml_context * ctx0,
+        int ndims,
+        int ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+float get_element(const struct ggml_tensor * t, int idx) {
+    return ((float *)t->data)[idx];
+}
+
+void set_element(struct ggml_tensor * t, int idx, float value) {
+    ((float *)t->data)[idx] = value;
+}
+
+bool check_gradient(
+        const char * op_name,
+        struct ggml_context * ctx0,
+        struct ggml_tensor * x[],
+        struct ggml_tensor * f,
+        int ndims,
+        int nargs,
+        float eps,
+        float max_error_abs,
+        float max_error_rel) {
+
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_reset  (&gf);
+    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_compute(ctx0, &gb);
+
+    ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
+    ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+
+    for (int i = 0; i < nargs; ++i) {
+        const int nelements = ggml_nelements(x[i]);
+        for (int k = 0; k < nelements; ++k) {
+            // compute gradient using finite differences
+            const float x0 = get_element(x[i], k);
+
+            set_element(x[i], k, x0 + eps);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f0 = ggml_get_f32_1d(f, 0);
+
+            set_element(x[i], k, x0 - eps);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f1 = ggml_get_f32_1d(f, 0);
+
+            const float g0 = (f0 - f1)/(2.0f*eps);
+
+            set_element(x[i], k, x0);
+
+            // compute gradient using backward graph
+            ggml_graph_reset  (&gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(ctx0, &gb);
+
+            const float g1 = get_element(x[i]->grad, k);
+
+            const float error_abs = fabsf(g0 - g1);
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+
+            if (error_abs > max_error_abs || error_rel > max_error_rel) {
+                printf("%s: ndims=%d, i=%d, k=%d, g0=%f, g1=%f, error_abs=%f, error_rel=%f\n",
+                        op_name, ndims, i, k, g0, g1, error_abs, error_rel);
+                assert(false);
+            }
+        }
+    }
+
+    return true;
+}
+
+// TODO: clean-up this ..
+bool check_mat_mul(
+        const struct ggml_tensor * y,
+        const struct ggml_tensor * x0,
+        const struct ggml_tensor * x1) {
+    float * dst  = (float *) y->data;
+    float * src0 = (float *) x0->data;
+    float * src1 = (float *) x1->data;
+
+    const int nc = x0->ne[1];
+    const int nr = x1->ne[1];
+    const int nk = x0->ne[0];
+
+    printf("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
+
+    printf("x0:\n");
+    for (int j = 0; j < x0->ne[1]; ++j) {
+        for (int i = 0; i < x0->ne[0]; ++i) {
+            printf("%6.3f ", src0[j*nk + i]);
+        }
+        printf("\n");
+    }
+    printf("\n");
+
+    printf("x1:\n");
+    for (int j = 0; j < x1->ne[1]; ++j) {
+        for (int i = 0; i < x1->ne[0]; ++i) {
+            printf("%6.3f ", src1[j*nk + i]);
+        }
+        printf("\n");
+    }
+    printf("\n");
+
+    printf("y: n_dims = %d, (%d, %d)\n", y->n_dims, y->ne[0], y->ne[1]);
+    for (int j = 0; j < y->ne[1]; ++j) {
+        for (int i = 0; i < y->ne[0]; ++i) {
+            printf("%6.3f ", dst[j*nr + i]);
+        }
+        printf("\n");
+    }
+
+    for (int i = 0; i < nr; ++i) {
+        for (int j = 0; j < nc; ++j) {
+            float sum = 0.0f;
+
+            for (int k = 0; k < nk; ++k) {
+                sum += src0[j*nk + k]*src1[i*nk + k];
+            }
+
+            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
+                printf("check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
+                assert(false);
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    int ne[4];
+
+    for (int iter = 0; iter < 1000; ++iter) {
+        struct ggml_context * ctx0 = ggml_init(params);
+
+        get_random_dims(ne, 4);
+
+        struct ggml_tensor * x[MAX_NARGS];
+
+        // add
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+                check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // sub
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
+
+                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // mul
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
+
+                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // div
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
+
+                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-2f);
+            }
+        }
+
+        // sqr
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
+
+                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // sqrt
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
+
+                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            }
+        }
+
+        // sum
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
+
+                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // abs (finite differences do not work)
+        //{
+        //    const int nargs = 1;
+
+        //    for (int ndims = 1; ndims <= 2; ++ndims) {
+        //        for (int i = 0; i < nargs; ++i) {
+        //            x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+        //            ggml_set_param(ctx0, x[i]);
+        //        }
+
+        //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
+
+        //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
+        //    }
+        //}
+
+        // mul_mat
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                {
+                    int ne2[4];
+                    get_random_dims(ne2, 4);
+                    ne2[0] = ne[0];
+                    x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                }
+
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                struct ggml_tensor * f = ggml_sum(ctx0, m);
+
+                printf("testing: mul_mat, [%d, %d] * [%d, %d]\n",
+                        x[1]->ne[0], x[1]->ne[1], x[0]->ne[0], x[0]->ne[1]);
+
+                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_mat_mul(m, x[1], x[0]);
+            }
+        }
+
+        ggml_free(ctx0);
+    }
+
+    return 0;
+}
diff --git a/tests/test-mul-mat0.c b/tests/test-mul-mat0.c
new file mode 100644 (file)
index 0000000..1215c4c
--- /dev/null
@@ -0,0 +1,316 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+#define MAX_NARGS 2
+
+float frand() {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+int irand(int n) {
+    return rand()%n;
+}
+
+void get_random_dims(int * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+struct ggml_tensor * get_random_tensor(
+        struct ggml_context * ctx0,
+        int ndims,
+        int ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+float get_element(const struct ggml_tensor * t, int idx) {
+    return ((float *)t->data)[idx];
+}
+
+void set_element(struct ggml_tensor * t, int idx, float value) {
+    ((float *)t->data)[idx] = value;
+}
+
+bool check_gradient(
+        const char * op_name,
+        struct ggml_context * ctx0,
+        struct ggml_tensor * x[],
+        struct ggml_tensor * f,
+        int ndims,
+        int nargs,
+        float eps,
+        float max_error_abs,
+        float max_error_rel) {
+
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_reset  (&gf);
+    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_compute(ctx0, &gb);
+
+    ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
+    ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+
+    for (int i = 0; i < nargs; ++i) {
+        const int nelements = ggml_nelements(x[i]);
+        for (int k = 0; k < nelements; ++k) {
+            // compute gradient using finite differences
+            const float x0 = get_element(x[i], k);
+
+            set_element(x[i], k, x0 + eps);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f0 = ggml_get_f32_1d(f, 0);
+
+            set_element(x[i], k, x0 - eps);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f1 = ggml_get_f32_1d(f, 0);
+
+            const float g0 = (f0 - f1)/(2.0f*eps);
+
+            set_element(x[i], k, x0);
+
+            // compute gradient using backward graph
+            ggml_graph_reset  (&gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(ctx0, &gb);
+
+            const float g1 = get_element(x[i]->grad, k);
+
+            const float error_abs = fabsf(g0 - g1);
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+
+            if (error_abs > max_error_abs || error_rel > max_error_rel) {
+                printf("%s: ndims=%d, i=%d, k=%d, g0=%f, g1=%f, error_abs=%f, error_rel=%f\n",
+                        op_name, ndims, i, k, g0, g1, error_abs, error_rel);
+                assert(false);
+            }
+        }
+    }
+
+    return true;
+}
+
+
+float mat_get(const struct ggml_tensor * t, int i0, int i1, int i2, int i3) {
+    const size_t nb0 = t->nb[0];
+    const size_t nb1 = t->nb[1];
+    const size_t nb2 = t->nb[2];
+    const size_t nb3 = t->nb[3];
+
+    return
+        *((float*) ((char*)t->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3));
+}
+
+bool check_mat_mul(
+        const struct ggml_tensor * y,
+        const struct ggml_tensor * x0,
+        const struct ggml_tensor * x1) {
+    float * dst  = (float *) y->data;
+    float * src0 = (float *) x0->data;
+    float * src1 = (float *) x1->data;
+
+    const int n00 = x0->ne[0];
+    const int n10 = x0->ne[1];
+    const int n20 = x0->ne[2];
+    const int n30 = x0->ne[3];
+
+    const int n01 = x1->ne[0];
+    const int n11 = x1->ne[1];
+    const int n21 = x1->ne[2];
+    const int n31 = x1->ne[3];
+
+    const int n02 = y->ne[0];
+    const int n12 = y->ne[1];
+    const int n22 = y->ne[2];
+    const int n32 = y->ne[3];
+
+    printf("x0: [%d, %d, %d, %d]\n", n00, n10, n20, n30);
+    for (int j = 0; j < n10; ++j) {
+        for (int i = 0; i < n00; ++i) {
+            printf("%6.3f ", mat_get(x0, i, j, 0, 0));
+        }
+        printf("\n");
+    }
+    printf("\n");
+
+    printf("x1: [%d, %d, %d, %d]\n", n01, n11, n21, n31);
+    for (int j = 0; j < n11; ++j) {
+        for (int i = 0; i < n01; ++i) {
+            printf("%6.3f ", mat_get(x1, i, j, 0, 0));
+        }
+        printf("\n");
+    }
+    printf("\n");
+
+    printf("y: [%d, %d, %d, %d]\n", n02, n12, n22, n32);
+    for (int j = 0; j < n12; ++j) {
+        for (int i = 0; i < n02; ++i) {
+            printf("%6.3f ", mat_get(y, i, j, 0, 0));
+        }
+        printf("\n");
+    }
+
+    for (int i3 = 0; i3 < n32; ++i3) {
+        for (int i2 = 0; i2 < n22; ++i2) {
+            for (int i1 = 0; i1 < n12; ++i1) {
+                for (int i0 = 0; i0 < n02; ++i0) {
+                    float sum = 0.0f;
+                    for (int k = 0; k < n00; ++k) {
+                        sum += mat_get(x0, k, i0, i2, i3) * mat_get(x1, k, i1, i2, i3);
+                    }
+                    if (fabsf(sum - mat_get(y, i0, i1, i2, i3)) > 1e-5) {
+                        printf("error: i0=%d, i1=%d, i2=%d, i3=%d, sum=%f, y=%f\n",
+                                i0, i1, i2, i3, sum, mat_get(y, i0, i1, i2, i3));
+                        assert(false);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    int ne[4];
+
+    for (int iter = 0; iter < 500; ++iter) {
+        struct ggml_context * ctx0 = ggml_init(params);
+
+        get_random_dims(ne, 4);
+
+        struct ggml_tensor * x[MAX_NARGS];
+
+        // mul_mat
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ne[1] = rand()%4 + 1;
+                x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                struct ggml_tensor * f = ggml_sum(ctx0, m);
+
+                printf("testing: mul_mat, [%d, %d, %d, %d] = [%d, %d, %d, %d] * [%d, %d, %d, %d]\n",
+                           m->ne[0],    m->ne[1],    m->ne[2],    m->ne[3],
+                        x[1]->ne[0], x[1]->ne[1], x[1]->ne[2], x[1]->ne[3],
+                        x[0]->ne[0], x[0]->ne[1], x[0]->ne[2], x[0]->ne[3]);
+
+                assert(m->ne[0] == x[1]->ne[1]);
+                assert(m->ne[1] == x[0]->ne[1]);
+                assert(m->ne[2] == x[0]->ne[2]);
+                assert(m->ne[3] == x[0]->ne[3]);
+
+                if (ndims <= 2) {
+                    check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                } else {
+                    struct ggml_cgraph gf = ggml_build_forward(m);
+                    ggml_graph_compute(ctx0, &gf);
+                }
+
+                check_mat_mul(m, x[1], x[0]);
+            }
+        }
+
+        // mul_mat (transposed)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ne[1] = ne[0];
+                ne[0] = rand()%4 + 1;
+                x[1] = ggml_transpose(ctx0, get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f));
+
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                struct ggml_tensor * f = ggml_sum(ctx0, m);
+
+                printf("testing: mul_mat, [%d, %d, %d, %d] = [%d, %d, %d, %d] * [%d, %d, %d, %d]\n",
+                           m->ne[0],    m->ne[1],    m->ne[2],    m->ne[3],
+                        x[1]->ne[0], x[1]->ne[1], x[1]->ne[2], x[1]->ne[3],
+                        x[0]->ne[0], x[0]->ne[1], x[0]->ne[2], x[0]->ne[3]);
+
+                assert(m->ne[0] == x[1]->ne[1]);
+                assert(m->ne[1] == x[0]->ne[1]);
+                assert(m->ne[2] == x[0]->ne[2]);
+                assert(m->ne[3] == x[0]->ne[3]);
+
+                if (ndims <= 2) {
+                    check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                } else {
+                    struct ggml_cgraph gf = ggml_build_forward(m);
+                    ggml_graph_compute(ctx0, &gf);
+                }
+
+                check_mat_mul(m, x[1], x[0]);
+            }
+        }
+        ggml_free(ctx0);
+    }
+
+    return 0;
+}
diff --git a/tests/test-vec0.c b/tests/test-vec0.c
new file mode 100644 (file)
index 0000000..5e3bfbd
--- /dev/null
@@ -0,0 +1,124 @@
+#include <stdio.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <time.h>
+
+const int N = 1 << 14;
+const int M = 1 << 14;
+
+void mul_mat_vec_f32_0(
+    const float * src0,
+    const float * src1,
+    float * dst,
+    unsigned nrows,
+    unsigned ncols) {
+    for (unsigned i = 0; i < nrows; i++) {
+        float sum = 0.0f;
+        for (unsigned j = 0; j < ncols; j++) {
+            sum += src0[i*ncols + j]*src1[j];
+        }
+        dst[i] = sum;
+    }
+}
+
+typedef float afloat __attribute__ ((__aligned__(32)));
+void mul_mat_vec_f32_1(
+    const afloat *restrict src0,
+    const afloat *restrict src1,
+    afloat *restrict dst,
+    unsigned nrows,
+    unsigned ncols) {
+    for (unsigned i = 0; i < nrows; i++) {
+        const afloat * restrict row = src0 + i*ncols;
+        const afloat * restrict col = src1;
+
+        float sum = 0.0f;
+
+        for (unsigned j = 0; j < ncols; j++) {
+            sum += *row++ * *col++;
+        }
+
+        dst[i] = sum;
+
+        //float sum[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+
+        //for (unsigned j = 0; j < ncols; j += 8) {
+        //    sum[0] += row[0]*col[0];
+        //    sum[1] += row[1]*col[1];
+        //    sum[2] += row[2]*col[2];
+        //    sum[3] += row[3]*col[3];
+        //    sum[4] += row[4]*col[4];
+        //    sum[5] += row[5]*col[5];
+        //    sum[6] += row[6]*col[6];
+        //    sum[7] += row[7]*col[7];
+
+        //    row += 8;
+        //    col += 8;
+        //}
+
+        //dst[i] = sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7];
+    }
+}
+
+void mul_mat_vec_f32_2(
+    const void * src0,
+    const void * src1,
+    void * dst,
+    unsigned nrows,
+    unsigned ncols) {
+    void * d = dst;
+    for (unsigned i = 0; i < nrows; i++) {
+        float sum = 0.0f;
+
+        const void * row = src0 + i*ncols*sizeof(float);
+        const void * col = src1;
+        for (unsigned j = 0; j < ncols; j++) {
+            sum += (*(float *)row) * (*(float *)col);
+            row += sizeof(float);
+            col += sizeof(float);
+        }
+        *(float *)d = sum;
+        d += sizeof(float);
+    }
+}
+
+int main(int argc, const char ** argv) {
+    //float * src0 = (float *)malloc(sizeof(float)*N*M);
+    //float * src1 = (float *)malloc(sizeof(float)*M);
+    //float * dst  = (float *)malloc(sizeof(float)*N);
+
+    afloat * src0 = (float *)(aligned_alloc(32, sizeof(float)*N*M));
+    afloat * src1 = (float *)(aligned_alloc(32, sizeof(float)*M));
+    afloat * dst  = (float *)(aligned_alloc(32, sizeof(float)*N));
+
+    for (unsigned i = 0; i < N*M; i++) {
+        src0[i] = i;
+    }
+
+    for (unsigned i = 0; i < M; i++) {
+        src1[i] = i;
+    }
+
+    const int nIter = 10;
+
+    const clock_t start = clock();
+
+    double sum = 0.0f;
+    for (int i = 0; i < nIter; i++) {
+        //mul_mat_vec_f32_0(src0, src1, dst, N, M);
+        mul_mat_vec_f32_1(src0, src1, dst, N, M);
+        //mul_mat_vec_f32_2(src0, src1, dst, N, M);
+        for (unsigned i = 0; i < N; i++) {
+            sum += dst[i];
+        }
+    }
+
+    {
+        const clock_t end = clock();
+        printf("%s: elapsed ticks: %ld\n", __func__, end - start);
+    }
+
+    printf("%f\n", sum);
+
+    return 0;
+}
diff --git a/tests/test-vec1.c b/tests/test-vec1.c
new file mode 100644 (file)
index 0000000..850c622
--- /dev/null
@@ -0,0 +1,546 @@
+#include <stdint.h>
+#include <stdio.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <time.h>
+#include <math.h>
+
+#include <sys/time.h>
+
+#include <immintrin.h>
+
+const int N = 1 << 14;
+const int M = 768;
+
+//
+// naive implementation
+//
+
+void mul_mat_vec_f32_0(
+    const float * restrict src0,
+    const float * restrict src1,
+    float * dst,
+    int nrows,
+    int ncols) {
+    for (int i = 0; i < nrows; i++) {
+        float sum = 0.0f;
+        for (int j = 0; j < ncols; j++) {
+            sum += src0[i*ncols + j]*src1[j];
+        }
+        dst[i] = sum;
+    }
+}
+
+//
+// SIMD with 8 32-bit floats
+//
+
+float reduce_vector8_0(__m256 v) {
+    __m128 v1 = _mm256_extractf128_ps(v, 0);
+    __m128 v2 = _mm256_extractf128_ps(v, 1);
+    __m128 v3 = _mm_add_ps(v1, v2);
+    __m128 v4 = _mm_shuffle_ps(v3, v3, 0x4e);
+    __m128 v5 = _mm_add_ps(v3, v4);
+    __m128 v6 = _mm_shuffle_ps(v5, v5, 0x11);
+    __m128 v7 = _mm_add_ps(v5, v6);
+    return _mm_cvtss_f32(v7);
+}
+
+// vectorized implementation using AVX
+void mul_mat_vec_f32_1(
+    const float * restrict src0,
+    const float * restrict src1,
+    float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols8 = ncols & ~7;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum = _mm256_setzero_ps();
+        for (int j = 0; j < ncols8; j += 8) {
+            __m256 a = _mm256_loadu_ps(src0 + i*ncols + j);
+            __m256 b = _mm256_loadu_ps(src1 + j);
+            __m256 c = _mm256_mul_ps(a, b);
+            sum = _mm256_add_ps(sum, c);
+        }
+        dst[i] = reduce_vector8_0(sum);
+
+        for (int j = ncols8; j < ncols; j++) {
+            dst[i] += src0[i*ncols + j]*src1[j];
+        }
+    }
+}
+
+void mul_mat_vec_f32_2(
+    const float * restrict src0,
+    const float * restrict src1,
+    float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols32 = ncols & ~31;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum0 = _mm256_setzero_ps();
+        __m256 sum1 = _mm256_setzero_ps();
+        __m256 sum2 = _mm256_setzero_ps();
+        __m256 sum3 = _mm256_setzero_ps();
+
+        const float * restrict src0_row = src0 + i*ncols;
+        for (int j = 0; j < ncols32; j += 32) {
+            __m256 a0 = _mm256_loadu_ps(src0_row + j + 0);
+            __m256 a1 = _mm256_loadu_ps(src0_row + j + 8);
+            __m256 a2 = _mm256_loadu_ps(src0_row + j + 16);
+            __m256 a3 = _mm256_loadu_ps(src0_row + j + 24);
+            __m256 b0 = _mm256_loadu_ps(src1 + j + 0);
+            __m256 b1 = _mm256_loadu_ps(src1 + j + 8);
+            __m256 b2 = _mm256_loadu_ps(src1 + j + 16);
+            __m256 b3 = _mm256_loadu_ps(src1 + j + 24);
+            sum0 = _mm256_fmadd_ps(a0, b0, sum0);
+            sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+            sum2 = _mm256_fmadd_ps(a2, b2, sum2);
+            sum3 = _mm256_fmadd_ps(a3, b3, sum3);
+        }
+        dst[i] = reduce_vector8_0(_mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3)));
+
+        for (int j = ncols32; j < ncols; j++) {
+            dst[i] += src0[i*ncols + j]*src1[j];
+        }
+    }
+}
+
+//
+// SIMD with 8 16-bit floats
+//
+
+static inline float fp32_from_bits(uint32_t w) {
+#if defined(__OPENCL_VERSION__)
+    return as_float(w);
+#elif defined(__CUDA_ARCH__)
+    return __uint_as_float((unsigned int) w);
+#elif defined(__INTEL_COMPILER)
+    return _castu32_f32(w);
+#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
+    return _CopyFloatFromInt32((__int32) w);
+#else
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32 = { w };
+    return fp32.as_value;
+#endif
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+#if defined(__OPENCL_VERSION__)
+       return as_uint(f);
+#elif defined(__CUDA_ARCH__)
+       return (uint32_t) __float_as_uint(f);
+#elif defined(__INTEL_COMPILER)
+       return _castf32_u32(f);
+#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
+       return (uint32_t) _CopyInt32FromFloat(f);
+#else
+       union {
+               float as_value;
+               uint32_t as_bits;
+       } fp32 = { f };
+       return fp32.as_bits;
+#endif
+}
+
+/*
+ * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to
+ * a 32-bit floating-point number in IEEE single-precision format.
+ *
+ * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+ * floating-point operations and bitcasts between integer and floating-point variables.
+ */
+static inline float fp16_ieee_to_fp32_value(uint16_t h) {
+    /*
+     * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word:
+     *      +---+-----+------------+-------------------+
+     *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+     *      +---+-----+------------+-------------------+
+     * Bits  31  26-30    16-25            0-15
+     *
+     * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits.
+     */
+    const uint32_t w = (uint32_t) h << 16;
+    /*
+     * Extract the sign of the input number into the high bit of the 32-bit word:
+     *
+     *      +---+----------------------------------+
+     *      | S |0000000 00000000 00000000 00000000|
+     *      +---+----------------------------------+
+     * Bits  31                 0-31
+     */
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    /*
+     * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word:
+     *
+     *      +-----+------------+---------------------+
+     *      |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
+     *      +-----+------------+---------------------+
+     * Bits  27-31    17-26            0-16
+     */
+    const uint32_t two_w = w + w;
+
+    /*
+     * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent
+     * of a single-precision floating-point number:
+     *
+     *       S|Exponent |          Mantissa
+     *      +-+---+-----+------------+----------------+
+     *      |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
+     *      +-+---+-----+------------+----------------+
+     * Bits   | 23-31   |           0-22
+     *
+     * Next, there are some adjustments to the exponent:
+     * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision
+     *   formats (0x7F - 0xF = 0x70)
+     * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number.
+     *   Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent
+     *   of the single-precision output must be 0xFF (max possible value). We do this correction in two steps:
+     *   - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested
+     *     by the difference in the exponent bias (see above).
+     *   - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of
+     *     exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias.
+     *     The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least
+     *     partially IEEE754-compliant implementations.
+     *
+     * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not
+     * operate on denormal inputs, and do not produce denormal results.
+     */
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    /*
+     * Convert denormalized half-precision inputs into single-precision results (always normalized).
+     * Zero inputs are also handled here.
+     *
+     * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits.
+     * First, we shift mantissa into bits 0-9 of the 32-bit word.
+     *
+     *                  zeros           |  mantissa
+     *      +---------------------------+------------+
+     *      |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
+     *      +---------------------------+------------+
+     * Bits             10-31                0-9
+     *
+     * Now, remember that denormalized half-precision numbers are represented as:
+     *    FP16 = mantissa * 2**(-24).
+     * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input
+     * and with an exponent which would scale the corresponding mantissa bits to 2**(-24).
+     * A normalized single-precision floating-point number is represented as:
+     *    FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127)
+     * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision
+     * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount.
+     *
+     * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number
+     * is zero, the constructed single-precision number has the value of
+     *    FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5
+     * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of
+     * the input half-precision number.
+     */
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    /*
+     * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the
+     *   input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the
+     *   input is either a denormal number, or zero.
+     * - Combine the result of conversion of exponent and mantissa with the sign of the input number.
+     */
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+/*
+ * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in
+ * IEEE half-precision format, in bit representation.
+ *
+ * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+ * floating-point operations and bitcasts between integer and floating-point variables.
+ */
+static inline uint16_t fp16_ieee_from_fp32_value(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+void mul_mat_vec_f16_0(
+    const uint16_t * src0,
+    const uint16_t * src1,
+             float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols8 = ncols & ~7;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum = _mm256_setzero_ps();
+
+        const uint16_t * src0_row = src0 + i * ncols;
+        for (int j = 0; j < ncols8; j += 8) {
+            __m256 a = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j)));
+            __m256 b = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j)));
+            sum = _mm256_fmadd_ps(a, b, sum);
+        }
+        dst[i] = reduce_vector8_0(sum);
+
+        for (int j = ncols8; j < ncols; j++) {
+            dst[i] += fp16_ieee_to_fp32_value(src0_row[j]) * fp16_ieee_to_fp32_value(src1[j]);
+        }
+    }
+}
+
+void mul_mat_vec_f16_1(
+    const uint16_t * src0,
+    const uint16_t * src1,
+             float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols16 = ncols & ~15;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum0 = _mm256_setzero_ps();
+        __m256 sum1 = _mm256_setzero_ps();
+
+        const uint16_t * src0_row = src0 + i * ncols;
+        for (int j = 0; j < ncols16; j += 16) {
+            __m256 a0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 0)));
+            __m256 a1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 8)));
+            __m256 b0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j)));
+            __m256 b1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j + 8)));
+            sum0 = _mm256_fmadd_ps(a0, b0, sum0);
+            sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+        }
+        dst[i] = reduce_vector8_0(sum0) + reduce_vector8_0(sum1);
+
+        for (int j = ncols16; j < ncols; j++) {
+            dst[i] += fp16_ieee_to_fp32_value(src0_row[j]) * fp16_ieee_to_fp32_value(src1[j]);
+        }
+    }
+}
+
+void mul_mat_vec_f16_2(
+    const uint16_t * src0,
+    const uint16_t * src1,
+             float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols32 = ncols & ~31;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum0 = _mm256_setzero_ps();
+        __m256 sum1 = _mm256_setzero_ps();
+        __m256 sum2 = _mm256_setzero_ps();
+        __m256 sum3 = _mm256_setzero_ps();
+
+        const uint16_t * src0_row = src0 + i * ncols;
+        for (int j = 0; j < ncols32; j += 32) {
+            __m256 a0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 0)));
+            __m256 a1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 8)));
+            __m256 a2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 16)));
+            __m256 a3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 24)));
+            __m256 b0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j)));
+            __m256 b1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j + 8)));
+            __m256 b2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j + 16)));
+            __m256 b3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src1 + j + 24)));
+            sum0 = _mm256_fmadd_ps(a0, b0, sum0);
+            sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+            sum2 = _mm256_fmadd_ps(a2, b2, sum2);
+            sum3 = _mm256_fmadd_ps(a3, b3, sum3);
+        }
+        dst[i] = reduce_vector8_0(sum0) + reduce_vector8_0(sum1) + reduce_vector8_0(sum2) + reduce_vector8_0(sum3);
+
+        for (int j = ncols32; j < ncols; j++) {
+            dst[i] += fp16_ieee_to_fp32_value(src0_row[j]) * fp16_ieee_to_fp32_value(src1[j]);
+        }
+    }
+}
+
+void mul_mat_vec_f16_3(
+    const uint16_t * src0,
+    const    float * src1,
+             float * dst,
+    int nrows,
+    int ncols) {
+
+    const int ncols32 = ncols & ~31;
+
+    for (int i = 0; i < nrows; i++) {
+        __m256 sum0 = _mm256_setzero_ps();
+        __m256 sum1 = _mm256_setzero_ps();
+        __m256 sum2 = _mm256_setzero_ps();
+        __m256 sum3 = _mm256_setzero_ps();
+
+        const uint16_t * src0_row = src0 + i * ncols;
+        for (int j = 0; j < ncols32; j += 32) {
+            __m256 a0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 0)));
+            __m256 a1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 8)));
+            __m256 a2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 16)));
+            __m256 a3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src0_row + j + 24)));
+            __m256 b0 = _mm256_loadu_ps(src1 + j);
+            __m256 b1 = _mm256_loadu_ps(src1 + j + 8);
+            __m256 b2 = _mm256_loadu_ps(src1 + j + 16);
+            __m256 b3 = _mm256_loadu_ps(src1 + j + 24);
+            sum0 = _mm256_fmadd_ps(a0, b0, sum0);
+            sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+            sum2 = _mm256_fmadd_ps(a2, b2, sum2);
+            sum3 = _mm256_fmadd_ps(a3, b3, sum3);
+        }
+        dst[i] = reduce_vector8_0(sum0) + reduce_vector8_0(sum1) + reduce_vector8_0(sum2) + reduce_vector8_0(sum3);
+
+        for (int j = ncols32; j < ncols; j++) {
+            dst[i] += fp16_ieee_to_fp32_value(src0_row[j]) * fp16_ieee_to_fp32_value(src1[j]);
+        }
+    }
+}
+
+uint64_t get_time_us() {
+    struct timeval tv;
+    gettimeofday(&tv, NULL);
+    return tv.tv_sec * 1000000 + tv.tv_usec;
+}
+
+int main(int argc, const char ** argv) {
+    float * src0 = (float *)malloc(sizeof(float)*N*M);
+    float * src1 = (float *)malloc(sizeof(float)*M);
+    float * dst  = (float *)malloc(sizeof(float)*N);
+
+    //float * src0 = (float *)(aligned_alloc(64, sizeof(float)*N*M));
+    //float * src1 = (float *)(aligned_alloc(64, sizeof(float)*M));
+    //float * dst  = (float *)(aligned_alloc(64, sizeof(float)*N));
+
+    for (int i = 0; i < N*M; i++) {
+        src0[i] = rand() / (float)RAND_MAX;
+    }
+
+    for (int i = 0; i < M; i++) {
+        src1[i] = rand() / (float)RAND_MAX;
+    }
+
+    // convert src0 and src1 to __fp16
+    uint16_t * src0_fp16 = (uint16_t *)(malloc(sizeof(uint16_t)*N*M));
+    uint16_t * src1_fp16 = (uint16_t *)(malloc(sizeof(uint16_t)*M));
+    //uint16_t * src0_fp16 = (uint16_t *)(aligned_alloc(64, sizeof(uint16_t)*N*M));
+    //uint16_t * src1_fp16 = (uint16_t *)(aligned_alloc(64, sizeof(uint16_t)*M));
+
+    {
+        const uint64_t t_start = get_time_us();
+
+        for (int i = 0; i < N*M; i++) {
+            src0_fp16[i] = fp16_ieee_from_fp32_value(src0[i]);
+            //printf("%f %f\n", src0[i], fp16_ieee_to_fp32_value(src0_fp16[i]));
+            //assert(!isnan(fp16_ieee_to_fp32_value(src0_fp16[i])));
+        }
+
+        for (int i = 0; i < M; i++) {
+            src1_fp16[i] = fp16_ieee_from_fp32_value(src1[i]);
+        }
+
+        const uint64_t t_end = get_time_us();
+        printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
+    }
+
+    for (int i = 0; i < 16; ++i) {
+        printf("%f %f\n", src0[i], fp16_ieee_to_fp32_value(src0_fp16[i]));
+    }
+
+    int method = 0;
+    if (argc > 1) {
+        method = atoi(argv[1]);
+    }
+
+    const int nIter = 1000;
+
+    const clock_t start = clock();
+    const uint64_t start_us = get_time_us();
+
+    double iM = 1.0/M;
+    double sum = 0.0f;
+    for (int i = 0; i < nIter; i++) {
+        if (method == 0) {
+            mul_mat_vec_f32_0(src0, src1, dst, N, M);
+        }
+
+        if (method == 1) {
+            mul_mat_vec_f32_1(src0, src1, dst, N, M);
+        }
+
+        if (method == 2) {
+            mul_mat_vec_f32_2(src0, src1, dst, N, M);
+        }
+
+        if (method == 3) {
+            mul_mat_vec_f16_0(src0_fp16, src1_fp16, dst, N, M);
+        }
+
+        if (method == 4) {
+            mul_mat_vec_f16_1(src0_fp16, src1_fp16, dst, N, M);
+        }
+
+        if (method == 5) {
+            mul_mat_vec_f16_2(src0_fp16, src1_fp16, dst, N, M);
+        }
+
+        if (method == 6) {
+            mul_mat_vec_f16_3(src0_fp16, src1, dst, N, M);
+        }
+    }
+
+    for (int i = 0; i < N; i++) {
+        sum += dst[i]*iM;
+    }
+
+    {
+        const clock_t end = clock();
+        const uint64_t end_us = get_time_us();
+        printf("%s: elapsed ticks: %ld\n", __func__, end - start);
+        printf("%s: elapsed us: %ld\n", __func__, end_us - start_us);
+    }
+
+    printf("%f\n", sum);
+
+    free(src0);
+    free(src1);
+    free(dst);
+
+    free(src0_fp16);
+    free(src1_fp16);
+
+    return 0;
+}
diff --git a/tests/test-vec2.c b/tests/test-vec2.c
new file mode 100644 (file)
index 0000000..44dfd4f
--- /dev/null
@@ -0,0 +1,200 @@
+#include <stdint.h>
+#include <stdio.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <time.h>
+#include <math.h>
+
+#include <sys/time.h>
+
+#include <arm_neon.h>
+
+const int N = 1 << 14;
+const int M = 768;
+
+//
+// naive implementation
+//
+
+void mul_mat_vec_f32_0(
+    const float * restrict src0,
+    const float * restrict src1,
+    float * dst,
+    int nrows,
+    int ncols) {
+    for (int i = 0; i < nrows; i++) {
+        float sum = 0.0f;
+        for (int j = 0; j < ncols; j++) {
+            sum += src0[i*ncols + j]*src1[j];
+        }
+        dst[i] = sum;
+    }
+}
+
+void mul_mat_vec_f16_0(
+    const __fp16 * src0,
+    const __fp16 * src1,
+           float * dst,
+    int nrows,
+    int ncols) {
+
+    const int n64 = ncols & ~63;
+
+    for (int r = 0; r < nrows; r++) {
+        float sumf = 0.0;
+
+        float16x8_t sum0 = vdupq_n_f16(0.0f);
+        float16x8_t sum1 = vdupq_n_f16(0.0f);
+        float16x8_t sum2 = vdupq_n_f16(0.0f);
+        float16x8_t sum3 = vdupq_n_f16(0.0f);
+        float16x8_t sum4 = vdupq_n_f16(0.0f);
+        float16x8_t sum5 = vdupq_n_f16(0.0f);
+        float16x8_t sum6 = vdupq_n_f16(0.0f);
+        float16x8_t sum7 = vdupq_n_f16(0.0f);
+
+        float16x8_t x0, x1, x2, x3, x4, x5, x6, x7;
+        float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+        const __fp16 * restrict p0 = src0 + r*ncols;
+
+        for (int i = 0; i < n64; i += 64) {
+            x0 = vld1q_f16(p0 + i + 0 );
+            x1 = vld1q_f16(p0 + i + 8 );
+            x2 = vld1q_f16(p0 + i + 16);
+            x3 = vld1q_f16(p0 + i + 24);
+            x4 = vld1q_f16(p0 + i + 32);
+            x5 = vld1q_f16(p0 + i + 40);
+            x6 = vld1q_f16(p0 + i + 48);
+            x7 = vld1q_f16(p0 + i + 56);
+
+            y0 = vld1q_f16(src1 + i + 0 );
+            y1 = vld1q_f16(src1 + i + 8 );
+            y2 = vld1q_f16(src1 + i + 16);
+            y3 = vld1q_f16(src1 + i + 24);
+            y4 = vld1q_f16(src1 + i + 32);
+            y5 = vld1q_f16(src1 + i + 40);
+            y6 = vld1q_f16(src1 + i + 48);
+            y7 = vld1q_f16(src1 + i + 56);
+
+            sum0 = vfmaq_f16(sum0, x0, y0);
+            sum1 = vfmaq_f16(sum1, x1, y1);
+            sum2 = vfmaq_f16(sum2, x2, y2);
+            sum3 = vfmaq_f16(sum3, x3, y3);
+            sum4 = vfmaq_f16(sum4, x4, y4);
+            sum5 = vfmaq_f16(sum5, x5, y5);
+            sum6 = vfmaq_f16(sum6, x6, y6);
+            sum7 = vfmaq_f16(sum7, x7, y7);
+        }
+
+        // TODO: F16 - better way to reduce this ?
+        float16x8_t sum = vaddq_f16(sum0, sum1);
+
+        sum = vaddq_f16(sum, sum2);
+        sum = vaddq_f16(sum, sum3);
+        sum = vaddq_f16(sum, sum4);
+        sum = vaddq_f16(sum, sum5);
+        sum = vaddq_f16(sum, sum6);
+        sum = vaddq_f16(sum, sum7);
+
+        sumf += sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7];
+
+        for (int j = n64; j < n64; j++) {
+            sumf += src0[r*ncols + j]*src1[j];
+        }
+
+        dst[r] = sumf;
+    }
+}
+
+uint64_t get_time_us() {
+    struct timeval tv;
+    gettimeofday(&tv, NULL);
+    return tv.tv_sec * 1000000 + tv.tv_usec;
+}
+
+int main(int argc, const char ** argv) {
+    float * src0 = (float *)malloc(sizeof(float)*N*M);
+    float * src1 = (float *)malloc(sizeof(float)*M);
+    float * dst  = (float *)malloc(sizeof(float)*N);
+
+    //float * src0 = (float *)(aligned_alloc(64, sizeof(float)*N*M));
+    //float * src1 = (float *)(aligned_alloc(64, sizeof(float)*M));
+    //float * dst  = (float *)(aligned_alloc(64, sizeof(float)*N));
+
+    for (int i = 0; i < N*M; i++) {
+        src0[i] = rand() / (float)RAND_MAX;
+    }
+
+    for (int i = 0; i < M; i++) {
+        src1[i] = rand() / (float)RAND_MAX;
+    }
+
+    // convert src0 and src1 to __fp16
+    __fp16 * src0_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*N*M));
+    __fp16 * src1_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*M));
+
+    {
+        const uint64_t t_start = get_time_us();
+
+        for (int i = 0; i < N*M; i++) {
+            src0_fp16[i] = src0[i];
+            //printf("%f %f\n", src0[i], src0_fp16[i]);
+            //assert(!isnan(src0_fp16[i]));
+        }
+
+        for (int i = 0; i < M; i++) {
+            src1_fp16[i] = src1[i];
+        }
+
+        const uint64_t t_end = get_time_us();
+        printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
+    }
+
+    for (int i = 0; i < 16; ++i) {
+        printf("%f %f\n", src0[i], src0_fp16[i]);
+    }
+
+    int method = 0;
+    if (argc > 1) {
+        method = atoi(argv[1]);
+    }
+
+    const int nIter = 1000;
+
+    const clock_t start = clock();
+    const uint64_t start_us = get_time_us();
+
+    double iM = 1.0/M;
+    double sum = 0.0f;
+    for (int i = 0; i < nIter; i++) {
+        if (method == 0) {
+            mul_mat_vec_f32_0(src0, src1, dst, N, M);
+        }
+
+        if (method == 1) {
+            mul_mat_vec_f16_0(src0_fp16, src1_fp16, dst, N, M);
+        }
+    }
+
+    for (int i = 0; i < N; i++) {
+        sum += dst[i]*iM;
+    }
+
+    {
+        const clock_t end = clock();
+        const uint64_t end_us = get_time_us();
+        printf("%s: elapsed ticks: %ld\n", __func__, end - start);
+        printf("%s: elapsed us: %llu\n", __func__, end_us - start_us);
+    }
+
+    printf("%f\n", sum);
+
+    free(src0);
+    free(src1);
+    free(dst);
+
+    free(src0_fp16);
+    free(src1_fp16);
+
+    return 0;
+}
diff --git a/tests/test0.c b/tests/test0.c
new file mode 100644 (file)
index 0000000..b9cb5fd
--- /dev/null
@@ -0,0 +1,42 @@
+#include "ggml/ggml.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_tensor * t1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 10);
+    struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I16, 10, 20);
+    struct ggml_tensor * t3 = ggml_new_tensor_3d(ctx0, GGML_TYPE_I32, 10, 20, 30);
+
+    assert(t1->n_dims == 1);
+    assert(t1->ne[0]  == 10);
+    assert(t1->nb[1]  == 10*sizeof(float));
+
+    assert(t2->n_dims == 2);
+    assert(t2->ne[0]  == 10);
+    assert(t2->ne[1]  == 20);
+    assert(t2->nb[1]  == 10*sizeof(int16_t));
+    assert(t2->nb[2]  == 10*20*sizeof(int16_t));
+
+    assert(t3->n_dims == 3);
+    assert(t3->ne[0]  == 10);
+    assert(t3->ne[1]  == 20);
+    assert(t3->ne[2]  == 30);
+    assert(t3->nb[1]  == 10*sizeof(int32_t));
+    assert(t3->nb[2]  == 10*20*sizeof(int32_t));
+    assert(t3->nb[3]  == 10*20*30*sizeof(int32_t));
+
+    ggml_print_objects(ctx0);
+
+    ggml_free(ctx0);
+
+    return 0;
+}
diff --git a/tests/test1.c b/tests/test1.c
new file mode 100644 (file)
index 0000000..c9b5921
--- /dev/null
@@ -0,0 +1,436 @@
+#include "ggml/ggml.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    {
+        struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+
+        ggml_set_param(ctx0, x);
+
+        struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * b = ggml_mul(ctx0, x, x);
+        struct ggml_tensor * f = ggml_mul(ctx0, b, a);
+
+        // a*x^2
+        // 2*a*x
+
+        ggml_print_objects(ctx0);
+
+        struct ggml_cgraph gf = ggml_build_forward(f);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x, 2.0f);
+        ggml_set_f32(a, 3.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(f->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("f     = %f\n", ggml_get_f32_1d(f, 0));
+        printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
+
+        assert(ggml_get_f32_1d(f, 0)       == 12.0f);
+        assert(ggml_get_f32_1d(x->grad, 0) == 12.0f);
+
+        ggml_set_f32(x, 3.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(f->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("f     = %f\n", ggml_get_f32_1d(f, 0));
+        printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
+
+        assert(ggml_get_f32_1d(f, 0)       == 27.0f);
+        assert(ggml_get_f32_1d(x->grad, 0) == 18.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-1-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-1-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 1.0f);
+        ggml_set_f32(x3, 0.0f);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y = ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2));
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
+        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
+
+        assert(ggml_get_f32_1d(y, 0)        == 12.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == 7.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 3.0f);
+
+        struct ggml_tensor * g1 = x1->grad;
+        struct ggml_tensor * g2 = x2->grad;
+
+        struct ggml_cgraph gbb = ggml_build_backward(ctx0, &gb, true);
+
+        ggml_graph_reset(&gb);
+        ggml_set_f32(g1->grad, 1.0f);
+        ggml_set_f32(g2->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gbb);
+
+        printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0));
+
+        assert(ggml_get_f32_1d(x1->grad, 0) == 3.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 1.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-2-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-2-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y = ggml_mul(ctx0, ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2)), x1);
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 4.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
+        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
+
+        assert(ggml_get_f32_1d(y, 0)        == 63.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == 51.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 9.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-3-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-3-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+        ggml_set_param(ctx0, x3);
+
+        struct ggml_tensor * y = ggml_mul(ctx0, ggml_mul(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x2, x2)), x3);
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 1.0f);
+        ggml_set_f32(x2, 2.0f);
+        ggml_set_f32(x3, 3.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
+        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
+        printf("df/dx3 = %f\n", ggml_get_f32_1d(x3->grad, 0));
+
+        assert(ggml_get_f32_1d(y, 0)        == 12.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == 24.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 12.0f);
+        assert(ggml_get_f32_1d(x3->grad, 0) == 4.0f);
+
+        struct ggml_tensor * g1 = x1->grad;
+        struct ggml_tensor * g2 = x2->grad;
+        struct ggml_tensor * g3 = x3->grad;
+
+        struct ggml_cgraph gbb = ggml_build_backward(ctx0, &gb, true);
+
+        ggml_graph_reset(&gb);
+        ggml_set_f32(g1->grad, 1.0f);
+        ggml_set_f32(g2->grad, 1.0f);
+        ggml_set_f32(g3->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gbb);
+
+        printf("H * [1, 1, 1] = [ %f %f %f ]\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x3->grad, 0));
+
+        assert(ggml_get_f32_1d(x1->grad, 0) == 56.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 34.0f);
+        assert(ggml_get_f32_1d(x3->grad, 0) == 12.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-4-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-4-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y = ggml_sum(ctx0, ggml_mul(ctx0, x1, x2));
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 5.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f %f %f\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x1->grad, 1),
+                ggml_get_f32_1d(x1->grad, 2));
+        printf("df/dx2 = %f %f %f\n",
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x2->grad, 1),
+                ggml_get_f32_1d(x2->grad, 2));
+
+        assert(ggml_get_f32_1d(y, 0)        == 45.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == 5.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == 3.0f);
+        assert(ggml_get_f32_1d(x1->grad, 1) == 5.0f);
+        assert(ggml_get_f32_1d(x2->grad, 1) == 3.0f);
+        assert(ggml_get_f32_1d(x1->grad, 2) == 5.0f);
+        assert(ggml_get_f32_1d(x2->grad, 2) == 3.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-5-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-5-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y =
+            ggml_sum(ctx0,
+                    ggml_add(ctx0,
+                        ggml_mul(ctx0, x1, x2),
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1),
+                            ggml_mul(ctx0, x1, x1)
+                            )
+                        )
+                    );
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 5.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f %f %f\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x1->grad, 1),
+                ggml_get_f32_1d(x1->grad, 2));
+        printf("df/dx2 = %f %f %f\n",
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x2->grad, 1),
+                ggml_get_f32_1d(x2->grad, 2));
+
+        assert(ggml_get_f32_1d(y, 0)              == -9.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == -7.0f);
+        assert(ggml_get_f32_1d(x1->grad, 1) == -7.0f);
+        assert(ggml_get_f32_1d(x1->grad, 2) == -7.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) ==  3.0f);
+        assert(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
+        assert(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-6-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-6-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y =
+            ggml_sum(ctx0,
+                    ggml_sub(ctx0,
+                        ggml_mul(ctx0, x1, x2),
+                        ggml_mul(ctx0,
+                            ggml_mul(ctx0, x1, x1),
+                            ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1)
+                            )
+                        )
+                    );
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 5.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f %f %f\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x1->grad, 1),
+                ggml_get_f32_1d(x1->grad, 2));
+        printf("df/dx2 = %f %f %f\n",
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x2->grad, 1),
+                ggml_get_f32_1d(x2->grad, 2));
+
+        assert(ggml_get_f32_1d(y, 0)        == 99.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == 17.0f);
+        assert(ggml_get_f32_1d(x1->grad, 1) == 17.0f);
+        assert(ggml_get_f32_1d(x1->grad, 2) == 17.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) ==  3.0f);
+        assert(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
+        assert(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-7-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-7-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
+
+        ggml_set_param(ctx0, x1);
+        ggml_set_param(ctx0, x2);
+
+        struct ggml_tensor * y =
+            ggml_abs(ctx0,
+                    ggml_sub(ctx0, x1, x2)
+                    );
+
+        struct ggml_cgraph gf = ggml_build_forward(y);
+        struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+        ggml_set_f32(x1, 3.0f);
+        ggml_set_f32(x2, 5.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f %f %f\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x1->grad, 1),
+                ggml_get_f32_1d(x1->grad, 2));
+        printf("df/dx2 = %f %f %f\n",
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x2->grad, 1),
+                ggml_get_f32_1d(x2->grad, 2));
+
+        assert(ggml_get_f32_1d(y, 0)        ==  2.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) == -1.0f);
+        assert(ggml_get_f32_1d(x1->grad, 1) == -1.0f);
+        assert(ggml_get_f32_1d(x1->grad, 2) == -1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) ==  1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 1) ==  1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 2) ==  1.0f);
+
+        ggml_set_f32(x1, 7.0f);
+        ggml_set_f32(x2, 5.0f);
+
+        ggml_graph_reset(&gf);
+        ggml_set_f32(y->grad, 1.0f);
+
+        ggml_graph_compute(ctx0, &gb);
+
+        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
+        printf("df/dx1 = %f %f %f\n",
+                ggml_get_f32_1d(x1->grad, 0),
+                ggml_get_f32_1d(x1->grad, 1),
+                ggml_get_f32_1d(x1->grad, 2));
+        printf("df/dx2 = %f %f %f\n",
+                ggml_get_f32_1d(x2->grad, 0),
+                ggml_get_f32_1d(x2->grad, 1),
+                ggml_get_f32_1d(x2->grad, 2));
+
+        assert(ggml_get_f32_1d(y, 0)        ==  2.0f);
+        assert(ggml_get_f32_1d(x1->grad, 0) ==  1.0f);
+        assert(ggml_get_f32_1d(x1->grad, 1) ==  1.0f);
+        assert(ggml_get_f32_1d(x1->grad, 2) ==  1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 0) == -1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 1) == -1.0f);
+        assert(ggml_get_f32_1d(x2->grad, 2) == -1.0f);
+
+        ggml_graph_dump_dot(&gf, NULL, "test1-8-forward.dot");
+        ggml_graph_dump_dot(&gb, &gf,  "test1-8-backward.dot");
+    }
+
+    ggml_free(ctx0);
+
+    return 0;
+}
diff --git a/tests/test2.c b/tests/test2.c
new file mode 100644 (file)
index 0000000..8fd5418
--- /dev/null
@@ -0,0 +1,166 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+bool is_close(float a, float b, float epsilon) {
+    return fabs(a - b) < epsilon;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_LBFGS);
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
+    opt_params.adam.alpha = 0.01f;
+
+    opt_params.n_threads = (argc > 1) ? atoi(argv[1]) : 8;
+
+    const float xi[] = {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f , 6.0f,  7.0f,  8.0f,  9.0f,  10.0f, };
+          float yi[] = { 15.0f, 25.0f, 35.0f, 45.0f, 55.0f, 65.0f, 75.0f, 85.0f, 95.0f, 105.0f, };
+
+    const int n = sizeof(xi)/sizeof(xi[0]);
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
+    struct ggml_tensor * y = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
+
+    for (int i = 0; i < n; i++) {
+        ((float *) x->data)[i] = xi[i];
+        ((float *) y->data)[i] = yi[i];
+    }
+
+    {
+        struct ggml_tensor * t0 = ggml_new_f32(ctx0, 0.0f);
+        struct ggml_tensor * t1 = ggml_new_f32(ctx0, 0.0f);
+
+        // initialize auto-diff parameters:
+        ggml_set_param(ctx0, t0);
+        ggml_set_param(ctx0, t1);
+
+        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)
+        struct ggml_tensor * f =
+            ggml_div(ctx0,
+                    ggml_sum(ctx0,
+                        ggml_sqr(ctx0,
+                            ggml_sub(ctx0,
+                                ggml_add(ctx0,
+                                    ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)),
+                                    ggml_repeat(ctx0, t0, x)),
+                                y)
+                            )
+                        ),
+                    ggml_new_f32(ctx0, 2.0f*n));
+
+        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
+
+        assert(res == GGML_OPT_OK);
+
+        printf("t0 = %f\n", ggml_get_f32_1d(t0, 0));
+        printf("t1 = %f\n", ggml_get_f32_1d(t1, 0));
+
+        assert(is_close(ggml_get_f32_1d(t0, 0),  5.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-3f));
+    }
+
+    {
+        struct ggml_tensor * t0 = ggml_new_f32(ctx0, -1.0f);
+        struct ggml_tensor * t1 = ggml_new_f32(ctx0,  9.0f);
+
+        ggml_set_param(ctx0, t0);
+        ggml_set_param(ctx0, t1);
+
+        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n
+        struct ggml_tensor * f =
+            ggml_mul(ctx0,
+                    ggml_new_f32(ctx0, 1.0/(2*n)),
+                    ggml_sum(ctx0,
+                        ggml_abs(ctx0,
+                            ggml_sub(ctx0,
+                                ggml_add(ctx0,
+                                    ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)),
+                                    ggml_repeat(ctx0, t0, x)),
+                                y)
+                            )
+                        )
+                    );
+
+
+        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
+
+        assert(res == GGML_OPT_OK);
+        assert(is_close(ggml_get_f32_1d(t0, 0),  5.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-3f));
+    }
+
+    {
+        struct ggml_tensor * t0 = ggml_new_f32(ctx0,  5.0f);
+        struct ggml_tensor * t1 = ggml_new_f32(ctx0, -4.0f);
+
+        ggml_set_param(ctx0, t0);
+        ggml_set_param(ctx0, t1);
+
+        // f = t0^2 + t1^2
+        struct ggml_tensor * f =
+            ggml_add(ctx0,
+                    ggml_sqr(ctx0, t0),
+                    ggml_sqr(ctx0, t1)
+                    );
+
+        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
+
+        assert(res == GGML_OPT_OK);
+        assert(is_close(ggml_get_f32_1d(f,  0), 0.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t0, 0), 0.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t1, 0), 0.0f, 1e-3f));
+    }
+
+    /////////////////////////////////////////
+
+    {
+        struct ggml_tensor * t0 = ggml_new_f32(ctx0, -7.0f);
+        struct ggml_tensor * t1 = ggml_new_f32(ctx0,  8.0f);
+
+        ggml_set_param(ctx0, t0);
+        ggml_set_param(ctx0, t1);
+
+        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2
+        struct ggml_tensor * f =
+            ggml_add(ctx0,
+                    ggml_sqr(ctx0,
+                        ggml_sub(ctx0,
+                            ggml_add(ctx0,
+                                t0,
+                                ggml_mul(ctx0, t1, ggml_new_f32(ctx0, 2.0f))),
+                            ggml_new_f32(ctx0, 7.0f)
+                            )
+                        ),
+                    ggml_sqr(ctx0,
+                        ggml_sub(ctx0,
+                            ggml_add(ctx0,
+                                ggml_mul(ctx0, t0, ggml_new_f32(ctx0, 2.0f)),
+                                t1),
+                            ggml_new_f32(ctx0, 5.0f)
+                            )
+                        )
+                    );
+
+        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
+
+        assert(res == GGML_OPT_OK);
+        assert(is_close(ggml_get_f32_1d(f,  0), 0.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t0, 0), 1.0f, 1e-3f));
+        assert(is_close(ggml_get_f32_1d(t1, 0), 3.0f, 1e-3f));
+    }
+
+    ggml_free(ctx0);
+
+    return 0;
+}
diff --git a/tests/test3.c b/tests/test3.c
new file mode 100644 (file)
index 0000000..8210a56
--- /dev/null
@@ -0,0 +1,95 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+bool is_close(float a, float b, float epsilon) {
+    return fabs(a - b) < epsilon;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 1024*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_LBFGS);
+    //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
+
+    opt_params.n_threads = (argc > 1) ? atoi(argv[1]) : 8;
+
+    const int NP = 1 << 12;
+    const int NF = 1 << 8;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    struct ggml_tensor * F = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, NF, NP);
+    struct ggml_tensor * l = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NP);
+
+    // regularization weight
+    struct ggml_tensor * lambda = ggml_new_f32(ctx0, 1e-5f);
+
+    srand(0);
+
+    for (int j = 0; j < NP; j++) {
+        const float ll = j < NP/2 ? 1.0f : -1.0f;
+        ((float *)l->data)[j] = ll;
+
+        for (int i = 0; i < NF; i++) {
+            ((float *)F->data)[j*NF + i] = ((ll > 0 && i < NF/2 ? 1.0f : ll < 0 && i >= NF/2 ? 1.0f : 0.0f) + ((float)rand()/(float)RAND_MAX - 0.5f)*0.1f)/(0.5f*NF);
+        }
+    }
+
+    {
+        // initial guess
+        struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NF), 0.0f);
+
+        ggml_set_param(ctx0, x);
+
+        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
+        struct ggml_tensor * f =
+            ggml_add(ctx0,
+                    ggml_div(ctx0,
+                        ggml_sum(ctx0,
+                            ggml_sqr(ctx0,
+                                ggml_sub(ctx0,
+                                    ggml_mul_mat(ctx0, F, x),
+                                    l)
+                                )
+                            ),
+                        ggml_new_f32(ctx0, NP)
+                        ),
+                    ggml_mul(ctx0,
+                        ggml_sum(ctx0, ggml_sqr(ctx0, x)),
+                        lambda)
+                    );
+
+        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
+
+        assert(res == GGML_OPT_OK);
+
+        // print results
+        for (int i = 0; i < 16; i++) {
+            printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
+        }
+        printf("...\n");
+        for (int i = NF - 16; i < NF; i++) {
+            printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
+        }
+        printf("\n");
+
+        for (int i = 0; i < NF; ++i) {
+            if (i < NF/2) {
+                assert(is_close(((float *)x->data)[i],  1.0f, 1e-2f));
+            } else {
+                assert(is_close(((float *)x->data)[i], -1.0f, 1e-2f));
+            }
+        }
+    }
+
+    ggml_free(ctx0);
+
+    return 0;
+}