]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : Metal inference (#1642)
authorGeorgi Gerganov <redacted>
Sun, 4 Jun 2023 20:34:30 +0000 (23:34 +0300)
committerGitHub <redacted>
Sun, 4 Jun 2023 20:34:30 +0000 (23:34 +0300)
* mtl : export the LLaMA computation graph

* ci : disable temporary

* mtl : adapt the MNIST example as starter

* mtl : no need for mtl-export tool, add cli arg for main instead

* mtl : export just a small part of the graph for now to make it easier

* mtl : move MSL code into separate file for easy editing

* mtl : initial get_rows_q4_0 kernel

* mtl : confirmed get_rows_q4_0 is working correctly

* mtl : add rms_norm kernel + confirm working

* mtl : add mul kernel + confirm working

* mtl : initial mul_mat Q4 kernel (wrong results)

* mtl : mul_mat fixes (still wrong)

* mtl : another mul_mat Q4 (still does not work)

* mtl : working mul_mat q4

* ggml : fix handling of "view" ops in ggml_graph_import()

* mtl : add rope kernel

* mtl : add reshape and transpose handling

* ggml : store offset as opt arg for ggml_view_xd() operators

* mtl : add cpy kernel + handle view ops

* mtl : confirm f16 x f32 attention mul mat

* mtl : add scale kernel

* mtl : add diag_mask_inf kernel

* mtl : fix soft_max kernel

* ggml : update ggml_nbytes() to handle non-contiguous tensors

* mtl : verify V tensor contents

* mtl : add f32 -> f32 cpy kernel

* mtl : add silu kernel

* mtl : add non-broadcast mul kernel

* mtl : full GPU inference of the computation graph

* mtl : optimize rms_norm and soft_max kernels

* mtl : add f16 mat x f32 vec multiplication kernel

* mtl : fix bug in f16 x f32 mul mat + speed-up computation

* mtl : faster mul_mat_q4_0_f32 kernel

* mtl : fix kernel signature + roll inner loop

* mtl : more threads for rms_norm + better timing

* mtl : remove printfs from inner loop

* mtl : simplify implementation

* mtl : add save/load vocab to ggml file

* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

* mtl : make it work with main example

Lots of hacks but at least now it generates text

* mtl : preparing for merge

* mtl : clean-up ggml mtl interface + suport scratch / inplace

* mtl : remove temp / debug code

* metal : final refactoring and simplification

* Revert "ci : disable temporary"

This reverts commit 98c267fc77fe811082f672538fc91bcfc9072d63.

* metal : add comments

* metal : clean-up stuff, fix typos

* readme : add Metal instructions

* readme : add example for main

17 files changed:
.gitignore
CMakeLists.txt
Makefile
README.md
examples/CMakeLists.txt
examples/common.cpp
examples/common.h
examples/main/main.cpp
examples/metal/CMakeLists.txt [new file with mode: 0644]
examples/metal/metal.cpp [new file with mode: 0644]
ggml-metal.h [new file with mode: 0644]
ggml-metal.m [new file with mode: 0644]
ggml-metal.metal [new file with mode: 0644]
ggml.c
ggml.h
llama.cpp
llama.h

index d231f3ff8ed3636891e4c09706b97d09a85ca94a..e4561ad7344c269fb632f8dd1a524a7c997a32be 100644 (file)
@@ -17,6 +17,7 @@ build-release/
 build-static/
 build-cublas/
 build-opencl/
+build-metal/
 build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/
index 21f4ec9ddd2676b75adf68620a1cce3c1bf2568e..1f2e78c0ffba418877355451e5b5a5026d09e49d 100644 (file)
@@ -64,13 +64,14 @@ if (NOT MSVC)
 endif()
 
 # 3rd party libs
-option(LLAMA_ACCELERATE                 "llama: enable Accelerate framework"                    ON)
-option(LLAMA_BLAS                       "llama: use BLAS"                                       OFF)
+option(LLAMA_ACCELERATE                      "llama: enable Accelerate framework"               ON)
+option(LLAMA_BLAS                            "llama: use BLAS"                                  OFF)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
-option(LLAMA_CUBLAS                     "llama: use cuBLAS"                                     OFF)
-set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
-set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING  "llama: y block size for dmmv CUDA kernels")
-option(LLAMA_CLBLAST                    "llama: use CLBlast"                                    OFF)
+option(LLAMA_CUBLAS                          "llama: use cuBLAS"                                OFF)
+set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
+set(LLAMA_CUDA_DMMV_Y       "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
+option(LLAMA_METAL                           "llama: use Metal"                                 OFF)
 
 option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES             "llama: build examples" ${LLAMA_STANDALONE})
@@ -183,7 +184,7 @@ if (LLAMA_CUBLAS)
 
         enable_language(CUDA)
 
-        set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+        set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
 
         add_compile_definitions(GGML_USE_CUBLAS)
         add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
@@ -200,12 +201,37 @@ if (LLAMA_CUBLAS)
     endif()
 endif()
 
+if (LLAMA_METAL)
+    find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
+    find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
+    find_library(METALKIT_FRAMEWORK         MetalKit                REQUIRED)
+    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
+
+    set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
+
+    add_compile_definitions(GGML_USE_METAL)
+    add_compile_definitions(GGML_METAL_NDEBUG)
+
+    # get full path to the file
+    #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
+
+    # copy ggml-metal.metal to bin directory
+    configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
+
+    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
+        ${FOUNDATION_LIBRARY}
+        ${METAL_FRAMEWORK}
+        ${METALKIT_FRAMEWORK}
+        ${METALPERFORMANCE_FRAMEWORK}
+        )
+endif()
+
 if (LLAMA_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
         message(STATUS "CLBlast found")
 
-        set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
+        set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)
 
         add_compile_definitions(GGML_USE_CLBLAST)
 
@@ -370,8 +396,10 @@ endif()
 add_library(ggml OBJECT
             ggml.c
             ggml.h
-            ${GGML_CUDA_SOURCES}
-            ${GGML_OPENCL_SOURCES})
+            ${GGML_SOURCES_CUDA}
+            ${GGML_SOURCES_OPENCL}
+            ${GGML_SOURCES_METAL}
+            )
 
 target_include_directories(ggml PUBLIC .)
 target_compile_features(ggml PUBLIC c_std_11) # don't bump
@@ -384,21 +412,25 @@ endif()
 add_library(llama
             llama.cpp
             llama.h
-            llama-util.h)
+            llama-util.h
+            )
 
 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
-target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+target_link_libraries(llama PRIVATE
+    ggml
+    ${LLAMA_EXTRA_LIBS}
+    )
 
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
 endif()
 
-if (GGML_CUDA_SOURCES)
+if (GGML_SOURCES_CUDA)
     message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
-    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
-    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+    set_property(TARGET ggml  PROPERTY CUDA_ARCHITECTURES OFF)
+    set_property(TARGET ggml  PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
     set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
 endif()
 
index 8e8d426c5d6bf5b32b74c6657d36b25f5bd9008d..1f910c3ec862967bd7ce5e0e95a5e7735ce491c2 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -105,6 +105,7 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
        #CFLAGS   += -mfma -mf16c -mavx
        #CXXFLAGS += -mfma -mf16c -mavx
 endif
+
 ifneq ($(filter ppc64%,$(UNAME_M)),)
        POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
        ifneq (,$(findstring POWER9,$(POWER9_M)))
@@ -116,6 +117,7 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
                CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
        endif
 endif
+
 ifndef LLAMA_NO_ACCELERATE
        # Mac M1 - include Accelerate framework.
        # `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time).
@@ -123,7 +125,8 @@ ifndef LLAMA_NO_ACCELERATE
                CFLAGS  += -DGGML_USE_ACCELERATE
                LDFLAGS += -framework Accelerate
        endif
-endif
+endif # LLAMA_NO_ACCELERATE
+
 ifdef LLAMA_OPENBLAS
        CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
        ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
@@ -131,11 +134,13 @@ ifdef LLAMA_OPENBLAS
        else
                LDFLAGS += -lopenblas
        endif
-endif
+endif # LLAMA_OPENBLAS
+
 ifdef LLAMA_BLIS
        CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis
        LDFLAGS += -lblis -L/usr/local/lib
-endif
+endif # LLAMA_BLIS
+
 ifdef LLAMA_CUBLAS
        CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
        CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
@@ -156,9 +161,10 @@ endif # LLAMA_CUDA_DMMV_Y
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
+
 ifdef LLAMA_CLBLAST
-       CFLAGS  += -DGGML_USE_CLBLAST
-       CXXFLAGS  += -DGGML_USE_CLBLAST
+       CFLAGS   += -DGGML_USE_CLBLAST
+       CXXFLAGS += -DGGML_USE_CLBLAST
        # Mac provides OpenCL as a framework
        ifeq ($(UNAME_S),Darwin)
                LDFLAGS += -lclblast -framework OpenCL
@@ -166,23 +172,38 @@ ifdef LLAMA_CLBLAST
                LDFLAGS += -lclblast -lOpenCL
        endif
        OBJS    += ggml-opencl.o
+
 ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
-endif
+endif # LLAMA_CLBLAST
+
+ifdef LLAMA_METAL
+       CFLAGS   += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
+       CXXFLAGS += -DGGML_USE_METAL
+       LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
+       OBJS     += ggml-metal.o
+
+ggml-metal.o: ggml-metal.m ggml-metal.h
+       $(CC) $(CFLAGS) -c $< -o $@
+endif # LLAMA_METAL
+
 ifneq ($(filter aarch64%,$(UNAME_M)),)
        # Apple M1, M2, etc.
        # Raspberry Pi 3, 4, Zero 2 (64-bit)
        CFLAGS   += -mcpu=native
        CXXFLAGS += -mcpu=native
 endif
+
 ifneq ($(filter armv6%,$(UNAME_M)),)
        # Raspberry Pi 1, Zero
        CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
 endif
+
 ifneq ($(filter armv7%,$(UNAME_M)),)
        # Raspberry Pi 2
        CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
 endif
+
 ifneq ($(filter armv8%,$(UNAME_M)),)
        # Raspberry Pi 3, 4, Zero 2 (32-bit)
        CFLAGS += -mfp16-format=ieee -mno-unaligned-access
index aba22b9a069ca3c9dde1ec66603606eefa63cd4b..4fc877c64ae5a556e2a39360974729baa99ad6f1 100644 (file)
--- a/README.md
+++ b/README.md
@@ -51,11 +51,10 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook
 
 - Plain C/C++ implementation without dependencies
-- Apple silicon first-class citizen - optimized via ARM NEON and Accelerate framework
+- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
 - AVX, AVX2 and AVX512 support for x86 architectures
 - Mixed F16 / F32 precision
 - 4-bit, 5-bit and 8-bit integer quantization support
-- Runs on the CPU
 - Supports OpenBLAS/Apple BLAS/ARM Performance Lib/ATLAS/BLIS/Intel MKL/NVHPC/ACML/SCSL/SGIMATH and [more](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) in BLAS
 - cuBLAS and CLBlast support
 
@@ -236,6 +235,32 @@ In order to build llama.cpp you have three different options.
     zig build -Drelease-fast
     ```
 
+### Metal Build
+
+Using Metal allows the computation to be executed on the GPU for Apple devices:
+
+- Using `make`:
+
+  ```bash
+  LLAMA_METAL=1 make
+  ```
+
+- Using `CMake`:
+
+    ```bash
+    mkdir build-metal
+    cd build-metal
+    cmake -DLLAMA_METAL=ON ..
+    cmake --build . --config Release
+    ```
+
+When built with Metal support, you can enable GPU inference with the `--gpu-layers|-ngl` command-line argument.
+Any value larger than 0 will offload the computation to the GPU. For example:
+
+```bash
+./main -m ./models/7B/ggml-model-q4_0.bin -n 128 -ngl 1
+```
+
 ### BLAS Build
 
 Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it:
@@ -369,7 +394,7 @@ Building the program with BLAS support may lead to some performance improvements
 
   Running:
 
-  The CLBlast build supports `--gpu-layers|-ngl` like  the CUDA version does.
+  The CLBlast build supports `--gpu-layers|-ngl` like the CUDA version does.
 
   To select the correct platform (driver) and device (GPU), you can use the environment variables `GGML_OPENCL_PLATFORM` and `GGML_OPENCL_DEVICE`.
   The selection can be a number (starting from 0) or a text string to search:
index e4ce5aca7b98b869bd94a7475226b0c1a90e1b9c..3deff4077f80e980bb91770075e726359059fa0d 100644 (file)
@@ -37,7 +37,10 @@ else()
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
-    if(LLAMA_BUILD_SERVER)
+    if (LLAMA_METAL)
+        add_subdirectory(metal)
+    endif()
+    if (LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
 endif()
index 32247cef77f590502867389c55ad2fc54e69fd33..b5810f28f49012cb80509c0d4ac44776fc2a0fb9 100644 (file)
@@ -299,6 +299,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
             params.mem_test = true;
+        } else if (arg == "--export") {
+            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -438,6 +440,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        number of layers to store in VRAM\n");
 #endif
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
+    fprintf(stderr, "  --export              export the computation graph to 'llama.ggml'\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
index fea9aa81a355a4985b38bb7b7ca23649c751db3d..66bdeb5e9287dcb63e05eed81c2c31d9d919555a 100644 (file)
@@ -71,6 +71,7 @@ struct gpt_params {
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
+    bool export_cgraph     = false; // export the computation graph
     bool verbose_prompt    = false; // print prompt tokens before generation
 };
 
index 57cc1e486bda9fd2d249e6d139d5fbc32693fee7..b4d129393255da8d3a079eedda58c0e214a12133 100644 (file)
@@ -134,6 +134,13 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
+    // export the cgraph and exit
+    if (params.export_cgraph) {
+        llama_eval_export(ctx, "llama.ggml");
+        llama_free(ctx);
+
+        return 0;
+    }
 
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
diff --git a/examples/metal/CMakeLists.txt b/examples/metal/CMakeLists.txt
new file mode 100644 (file)
index 0000000..a8c4284
--- /dev/null
@@ -0,0 +1,3 @@
+set(TEST_TARGET metal)
+add_executable(${TEST_TARGET} metal.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
diff --git a/examples/metal/metal.cpp b/examples/metal/metal.cpp
new file mode 100644 (file)
index 0000000..77aca94
--- /dev/null
@@ -0,0 +1,102 @@
+// Evaluate a statically exported ggml computation graph with Metal
+//
+// - First, export a LLaMA graph:
+//
+//  $ ./bin/main -m ../models/7B/ggml-model-q4_0.bin --export
+//
+// - Run this tool to evaluate the exported graph:
+//
+//  $ ./bin/metal llama.ggml
+//
+// The purpose of this tool is mostly for debugging and demonstration purposes.
+// The main limitation of exporting computation graphs is that their sizes are static which often
+// can be a problem for real-world applications.
+//
+
+#include "ggml.h"
+#include "ggml-metal.h"
+
+#include <cstdio>
+#include <cstring>
+#include <cstdlib>
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    if (argc != 2) {
+        fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
+        return -1;
+    }
+
+    const char * fname_cgraph = argv[1];
+
+    // load the compute graph
+    struct ggml_context * ctx_data = NULL;
+    struct ggml_context * ctx_eval = NULL;
+
+    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+    gf.n_threads = 1;
+
+    // this allocates all Metal resources and memory buffers
+    auto * ctx_metal = ggml_metal_init();
+
+    ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data));
+    ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval));
+
+    // main
+    {
+        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
+        *(int32_t *) input->data = 1; // BOS
+
+        ggml_metal_set_tensor(ctx_metal, input);
+
+        // warmup
+        ggml_metal_graph_compute(ctx_metal, &gf);
+
+        const int n_iter = 16;
+
+        const int64_t t0 = ggml_time_us();
+
+        // the actual inference happens here
+        for (int i = 0; i < n_iter; ++i) {
+            ggml_metal_graph_compute(ctx_metal, &gf);
+        }
+
+        const int64_t t1 = ggml_time_us();
+
+        printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
+    }
+
+    // debug output
+    {
+        struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
+        ggml_metal_get_tensor(ctx_metal, logits);
+
+        float * ptr = (float *) ggml_get_data(logits);
+
+        printf("logits: ");
+        for (int i = 0; i < 10; i++) {
+            printf("%8.4f ", ptr[i]);
+        }
+        printf("\n");
+        int imax = 0;
+        double sum = 0.0;
+        double vmax = -1e9;
+        for (int i = 0; i < 32000; i++) {
+            sum += (double) ptr[i];
+            if (ptr[i] > vmax) {
+                vmax = ptr[i];
+                imax = i;
+            }
+        }
+        printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+    }
+
+    ggml_metal_free(ctx_metal);
+
+    ggml_free(ctx_data);
+    ggml_free(ctx_eval);
+
+    return 0;
+}
+
diff --git a/ggml-metal.h b/ggml-metal.h
new file mode 100644 (file)
index 0000000..a9441a9
--- /dev/null
@@ -0,0 +1,63 @@
+// An interface allowing to compute ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
+//
+// How it works?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
+
+#pragma once
+
+#include <stddef.h>
+#include <stdbool.h>
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 16
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_metal_context;
+
+struct ggml_metal_context * ggml_metal_init(void);
+void ggml_metal_free(struct ggml_metal_context * ctx);
+
+// creates a mapping between a host memory buffer and a device memory buffer
+// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
+// - the mapping is used during computation to determine the arguments of the compute kernels
+// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
+//
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                       const char * name,
+                             void * data,
+                           size_t   size);
+
+// set data from host memory into the device
+void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// get data from the device into host memory
+void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// same as ggml_graph_compute but uses Metal
+void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/ggml-metal.m b/ggml-metal.m
new file mode 100644 (file)
index 0000000..3cb423a
--- /dev/null
@@ -0,0 +1,672 @@
+#import "ggml-metal.h"
+
+#import "ggml.h"
+
+#import <Foundation/Foundation.h>
+
+#import <Metal/Metal.h>
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+#ifdef GGML_METAL_NDEBUG
+#define metal_printf(...)
+#else
+#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
+#define UNUSED(x) (void)(x)
+
+struct ggml_metal_buffer {
+    const char * name;
+
+    void   * data;
+    size_t   size;
+
+    id<MTLBuffer> metal;
+};
+
+struct ggml_metal_context {
+    float * logits;
+
+    id<MTLDevice>       device;
+    id<MTLCommandQueue> queue;
+    id<MTLLibrary>      library;
+
+    int n_buffers;
+    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+    // custom kernels
+#define GGML_METAL_DECL_KERNEL(name) \
+    id<MTLFunction>             function_##name; \
+    id<MTLComputePipelineState> pipeline_##name
+
+    GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(mul);
+    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+
+#undef GGML_METAL_DECL_KERNEL
+};
+
+// MSL code
+// TODO: move the contents here when ready
+//       for now it is easier to work in a separate file
+static NSString * const msl_library_source = @"see metal.metal";
+
+struct ggml_metal_context * ggml_metal_init(void) {
+    fprintf(stderr, "%s: allocating\n", __func__);
+
+    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+    ctx->device = MTLCreateSystemDefaultDevice();
+    ctx->queue  = [ctx->device newCommandQueue];
+
+    // determine if we can use MPS
+    if (MPSSupportsMTLDevice(ctx->device)) {
+        fprintf(stderr, "%s: using MPS\n", __func__);
+    } else {
+        fprintf(stderr, "%s: not using MPS\n", __func__);
+        GGML_ASSERT(false && "MPS not supported");
+    }
+
+#if 0
+    // compile from source string and show compile log
+    {
+        NSError * error = nil;
+
+        ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#else
+    UNUSED(msl_library_source);
+
+    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
+    {
+        NSError * error = nil;
+
+        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
+        NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"];
+        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+
+        NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+
+        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#endif
+
+    // load kernels
+    {
+#define GGML_METAL_ADD_KERNEL(name) \
+        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
+        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+
+        GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(mul);
+        GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+
+#undef GGML_METAL_ADD_KERNEL
+    }
+
+    return ctx;
+}
+
+void ggml_metal_free(struct ggml_metal_context * ctx) {
+    fprintf(stderr, "%s: deallocating\n", __func__);
+
+    free(ctx);
+}
+
+// finds the Metal buffer that contains the tensor data on the GPU device
+// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
+// Metal buffer based on the host memory pointer
+//
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+
+        if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+            *offs = (size_t) ioffs;
+
+            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+
+            return ctx->buffers[i].metal;
+        }
+    }
+
+    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+
+    return nil;
+}
+
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                     const char * name,
+                           void * data,
+                         size_t   size) {
+    if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
+        fprintf(stderr, "%s: too many buffers\n", __func__);
+        return false;
+    }
+
+    if (data) {
+        // verify that the buffer does not overlap with any of the existing buffers
+        for (int i = 0; i < ctx->n_buffers; ++i) {
+            const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
+
+            if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+                fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+                return false;
+            }
+        }
+
+        ctx->buffers[ctx->n_buffers].name = name;
+        ctx->buffers[ctx->n_buffers].data = data;
+        ctx->buffers[ctx->n_buffers].size = size;
+        ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
+
+        ++ctx->n_buffers;
+
+        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, size / 1024.0 / 1024.0);
+    }
+
+    return true;
+}
+
+void ggml_metal_set_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
+}
+
+void ggml_metal_get_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
+}
+
+void ggml_metal_graph_compute(
+        struct ggml_metal_context * ctx,
+             struct ggml_cgraph * gf) {
+    metal_printf("%s: evaluating graph\n", __func__);
+
+    size_t offs_src0 = 0;
+    size_t offs_src1 = 0;
+    size_t offs_dst  = 0;
+
+    id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBuffer];
+    id<MTLComputeCommandEncoder> encoder = nil;
+
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+        struct ggml_tensor * src0 = gf->nodes[i]->src0;
+        struct ggml_tensor * src1 = gf->nodes[i]->src1;
+        struct ggml_tensor * dst  = gf->nodes[i];
+
+        const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+        const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+        const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+        const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+        const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+        const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+        const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+        const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+        const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+        const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+        const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+        const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+        const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+        const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+        const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+        const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+        const int64_t  ne0  = dst ? dst->ne[0] : 0;
+        const int64_t  ne1  = dst ? dst->ne[1] : 0;
+        const int64_t  ne2  = dst ? dst->ne[2] : 0;
+        const int64_t  ne3  = dst ? dst->ne[3] : 0;
+
+        const uint64_t nb0  = dst ? dst->nb[0] : 0;
+        const uint64_t nb1  = dst ? dst->nb[1] : 0;
+        const uint64_t nb2  = dst ? dst->nb[2] : 0;
+        const uint64_t nb3  = dst ? dst->nb[3] : 0;
+
+        const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+        const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+        const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+        id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+        id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+        id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
+
+        //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        //if (src0) {
+        //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+        //            ggml_is_contiguous(src0), src0->name);
+        //}
+        //if (src1) {
+        //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+        //            ggml_is_contiguous(src1), src1->name);
+        //}
+        //if (dst) {
+        //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+        //            dst->name);
+        //}
+
+        switch (dst->op) {
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_TRANSPOSE:
+            case GGML_OP_PERMUTE:
+                {
+                    // noop
+                } break;
+            case GGML_OP_ADD:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_add];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_MUL:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    if (ggml_nelements(src1) == ne10) {
+                        // src1 is a row
+                        [encoder setComputePipelineState:ctx->pipeline_mul_row];
+                    } else {
+                        [encoder setComputePipelineState:ctx->pipeline_mul];
+                    }
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SCALE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const float scale = *(const float *) src1->data;
+
+                    [encoder setComputePipelineState:ctx->pipeline_scale];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SILU:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_RELU:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const int nth = 32;
+
+                    [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
+            case GGML_OP_DIAG_MASK_INF:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const int n_past = ((int32_t *)(src1->data))[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_MUL_MAT:
+                {
+                    // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+                    GGML_ASSERT(ne00 == ne10);
+                    GGML_ASSERT(ne02 == ne12);
+
+                    if (ggml_is_contiguous(src0) &&
+                        ggml_is_contiguous(src1) &&
+                        (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+                        if (encoder != nil) {
+                            [encoder endEncoding];
+                            encoder = nil;
+                        }
+
+                        MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                        MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                        // for F32 x F32 we use MPS
+                        MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+                        MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+                        MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+                        MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                        // we need to do ne02 multiplications
+                        // TODO: is there a way to do this in parallel - currently very slow ..
+                        // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+                        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                            size_t offs_src0_cur = offs_src0 + i02*nb02;
+                            size_t offs_src1_cur = offs_src1 + i02*nb12;
+                            size_t offs_dst_cur  = offs_dst  + i02*nb2;
+
+                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                            MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                        }
+                    } else {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        int nth0 = 32;
+                        int nth1 = 1;
+
+                        // use custom matrix x vector kernel
+                        switch (src0t) {
+                            case GGML_TYPE_Q4_0:
+                                {
+                                    GGML_ASSERT(ne02 == 1);
+                                    GGML_ASSERT(ne12 == 1);
+
+                                    nth0 = 8;
+                                    nth1 = 4;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                } break;
+                            case GGML_TYPE_F16:
+                                {
+                                    GGML_ASSERT(ne02 == ne12);
+
+                                    nth0 = 32;
+                                    nth1 = 1;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                } break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+
+
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+
+                        if (src0t == GGML_TYPE_Q4_0) {
+                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        } else {
+                            [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                    }
+                } break;
+            case GGML_OP_GET_ROWS:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    switch (src0->type) {
+                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                        default: GGML_ASSERT(false && "not implemented");
+                    }
+
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+
+                    const int64_t n = ggml_nelements(src1);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_RMS_NORM:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const float eps = 1e-6f;
+
+                    const int nth = 256;
+
+                    [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                    const int64_t nrows = ggml_nrows(src0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
+            case GGML_OP_ROPE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];
+
+                    const int n_past = ((int32_t *)(src1->data))[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_rope];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
+                    [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
+                    [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_CPY:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    const int nth = 32;
+
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                switch (dstt) {
+                                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                    case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                    default: GGML_ASSERT(false && "not implemented");
+                                };
+                            } break;
+                        default: GGML_ASSERT(false && "not implemented");
+                    }
+
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
+            default:
+                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                GGML_ASSERT(false);
+        }
+    }
+
+    if (encoder != nil) {
+        [encoder endEncoding];
+        encoder = nil;
+    }
+
+    [command_buffer commit];
+    [command_buffer waitUntilCompleted];
+
+    {
+        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
+        UNUSED(time_elapsed);
+
+        metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+    }
+}
diff --git a/ggml-metal.metal b/ggml-metal.metal
new file mode 100644 (file)
index 0000000..4bedc8e
--- /dev/null
@@ -0,0 +1,489 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+    half    d;             // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
+    const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+kernel void kernel_add(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
+}
+
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig];
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_soft_max(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    // parallel max
+    buf[tpitg[0]] = -INFINITY;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float max = buf[0];
+
+    // parallel sum
+    buf[tpitg[0]] = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] += buf[tpitg[0] + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float sum = buf[0];
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    }
+}
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_get_rows_q4_0(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_0(
+            (device const block_q4_0 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00] * x[i00];
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float mean  = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_mul_mat_q4_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2  tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_0;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    sum[ith] = 0.0f;
+
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
+        device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
+
+        const float d = (float)((x + i)->d);
+
+        const uchar4 x0v = *(x0p + tpitg.y);
+        const float4 y0v = *(y0p + tpitg.y + 0);
+        const float4 y1v = *(y0p + tpitg.y + 4);
+
+        float acc = 0.0f;
+
+        for (int j = 0; j < 4; ++j) {
+            const int x0 = x0v[j] & 0x0F;
+            const int x1 = x0v[j] >>   4;
+
+            const float y0 = y0v[j];
+            const float y1 = y1v[j];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        sum[ith] += acc*d;
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = nth/2; i > 0; i /= 2) {
+        if (ith < i) {
+            sum[ith] += sum[ith + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (ith == 0) {
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tpig[[thread_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_rope(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant       int & n_past,
+        constant       int & n_dims,
+        constant       int & mode,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i3 = tpig[2];
+    const int64_t i2 = tpig[1];
+    const int64_t i1 = tpig[0];
+
+    const bool is_neox = mode & 2;
+    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+    float theta = (float)p;
+
+    if (!is_neox) {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        // TODO: implement
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
diff --git a/ggml.c b/ggml.c
index 91552c94caade83346c2b9ad3f7ea928f6039165..00bbee503f52a755eafdb90c00fc5af241538535 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3732,7 +3732,14 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    // this should handle cases where the tensor is not contiguous in memory
+    // probaby just:
+    //
+    //     return tensor->ne[3]*tensor->nb[3]
+    //
+    // is enough, but just in case, adding the second part
+
+    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
 }
 
 int ggml_blck_size(enum ggml_type type) {
@@ -3814,11 +3821,11 @@ size_t ggml_tensor_overhead(void) {
     return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
 }
 
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -5802,10 +5809,18 @@ struct ggml_tensor * ggml_view_1d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
+    result->opt[0] = offs;
 
     if (is_node) {
         memcpy(result->padding, &offset, sizeof(offset));
@@ -5834,6 +5849,13 @@ struct ggml_tensor * ggml_view_2d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
@@ -5842,6 +5864,7 @@ struct ggml_tensor * ggml_view_2d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
+    result->opt[0] = offs;
 
     if (is_node) {
         memcpy(result->padding, &offset, sizeof(offset));
@@ -5872,6 +5895,13 @@ struct ggml_tensor * ggml_view_3d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = result->nb[2]*ne2;
@@ -5880,6 +5910,7 @@ struct ggml_tensor * ggml_view_3d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
+    result->opt[0] = offs;
 
     if (is_node) {
         memcpy(result->padding, &offset, sizeof(offset));
@@ -5912,6 +5943,13 @@ struct ggml_tensor * ggml_view_4d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = nb3;
@@ -5920,6 +5958,7 @@ struct ggml_tensor * ggml_view_4d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
+    result->opt[0] = offs;
 
     if (is_node) {
         memcpy(result->padding, &offset, sizeof(offset));
@@ -9252,7 +9291,7 @@ static void ggml_compute_forward_rms_norm_f32(
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                float mean = sum/ne00;
+                const float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -11163,7 +11202,7 @@ static void ggml_compute_forward_rope_f32(
                         theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
@@ -11184,7 +11223,7 @@ static void ggml_compute_forward_rope_f32(
                             const int64_t i0 = ib*n_dims + ic/2;
 
                             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                             const float x0 = src[0];
                             const float x1 = src[n_dims/2];
@@ -14588,7 +14627,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
+    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
@@ -14602,7 +14641,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
@@ -14615,8 +14654,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }
 
 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work      == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work      == NULL);
+    //assert(cgraph->work_size == 0);
 
     uint64_t size_eval = 0;
 
@@ -14837,7 +14876,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
     // read file into data
     {
         FILE * fin = fopen(fname, "rb");
-
         if (!fin) {
             fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
             return result;
@@ -14977,6 +15015,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                 op     = *(const uint32_t *) ptr; ptr += sizeof(op);
                 n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
+                enum ggml_op eop = (enum ggml_op) op;
+
                 int64_t ne[GGML_MAX_DIMS];
                 size_t  nb[GGML_MAX_DIMS];
 
@@ -14991,42 +15031,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     nb[j] = nb_cur;
                 }
 
-                struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
-
-                tensor->op = (enum ggml_op) op;
+                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
 
-                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
+                const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
 
-                memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+                const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
 
-                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                    tensor->nb[j] = nb[j];
-                }
+                struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
 
                 // parse args
-                {
-                    struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
-                        &tensor->src0,
-                        &tensor->src1,
-                    };
+                for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                    const int32_t arg_idx = ptr_arg_idx[j];
 
-                    for (int j = 0; j < GGML_MAX_OPT; ++j) {
-                        args[2 + j] = &tensor->opt[j];
+                    if (arg_idx == -1) {
+                        continue;
                     }
 
-                    for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
-                        const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
+                    if (arg_idx < GGML_MAX_NODES) {
+                        args[j] = result.leafs[arg_idx];
+                    } else {
+                        args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                    }
+                }
 
-                        if (arg_idx == -1) {
-                            continue;
-                        }
+                // create the tensor
+                // "view" operations are handled differently
+                // TODO: handle inplace ops - currently a copy is always made
 
-                        if (arg_idx < GGML_MAX_NODES) {
-                            *args[j] = result.leafs[arg_idx];
-                        } else {
-                            *args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
-                        }
-                    }
+                struct ggml_tensor * tensor = NULL;
+
+                switch (eop) {
+                    // TODO: implement other view ops
+                    case GGML_OP_RESHAPE:
+                        {
+                            tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+                        } break;
+                    case GGML_OP_VIEW:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+                            uint64_t offs;
+                            memcpy(&offs, args[2]->data, sizeof(offs));
+
+                            tensor->data = ((char *) tensor->data) + offs;
+                        } break;
+                    case GGML_OP_TRANSPOSE:
+                        {
+                            tensor = ggml_transpose(*ctx_eval, args[0]);
+                        } break;
+                    case GGML_OP_PERMUTE:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                        } break;
+                    default:
+                        {
+                            tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                            tensor->op = eop;
+                        } break;
+                }
+
+                memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    tensor->nb[j] = nb[j];
+                }
+
+                tensor->src0 = args[0];
+                tensor->src1 = args[1];
+
+                for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                    tensor->opt[j] = args[2 + j];
                 }
 
                 result.nodes[i] = tensor;
diff --git a/ggml.h b/ggml.h
index 60c0ad8bfa1c0ffeb7ea2be39b29c606ff3479dd..2ea87ce9a974931910bbe80bcaead847ee40f99c 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -425,6 +425,7 @@ extern "C" {
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
     GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows    (const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);
@@ -441,13 +442,16 @@ extern "C" {
     // TODO: temporary until model loading of ggml examples is refactored
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
+    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
+
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
 
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void    ggml_free(struct ggml_context * ctx);
+    GGML_API void                  ggml_free(struct ggml_context * ctx);
 
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
index f70b26c0fb108054ce94086029d42633b5df145f..bc58ad960c1390e1a30ef4cc448677accc1160b5 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #include "ggml-opencl.h"
 #endif
 
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 #include <array>
 #include <ctime>
 #include <cinttypes>
@@ -243,6 +247,10 @@ struct llama_context {
     llama_ctx_buffer buf_compute;
     llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
 
+#ifdef GGML_USE_METAL
+    ggml_metal_context * ctx_metal = NULL;
+#endif
+
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
 
@@ -1088,7 +1096,7 @@ static void llama_model_load_internal(
             mmapped_size - vram_total + // weights in VRAM not in memory
             MEM_REQ_SCRATCH0().at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
+            MEM_REQ_EVAL().at    (model.type);
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
@@ -1195,17 +1203,19 @@ static bool llama_model_load(
 
 // evaluate the transformer
 //
-//   - lctx:      llama context
-//   - tokens:    new batch of tokens to process
-//   - n_past:    the context size so far
-//   - n_threads: number of threads to use
+//   - lctx:         llama context
+//   - tokens:       new batch of tokens to process
+//   - n_past:       the context size so far
+//   - n_threads:    number of threads to use
+//   - cgraph_fname: filename of the exported computation graph
 //
 static bool llama_eval_internal(
-        llama_context & lctx,
-    const llama_token * tokens,
-            const int   n_tokens,
-            const int   n_past,
-            const int   n_threads) {
+        llama_context &  lctx,
+    const llama_token *  tokens,
+            const int    n_tokens,
+            const int    n_past,
+            const int    n_threads,
+            const char * cgraph_fname) {
 
     // enforce that the first token is BOS
     if (n_past == 0 && tokens[0] != llama_token_bos()) {
@@ -1251,13 +1261,18 @@ static bool llama_eval_internal(
     ggml_set_name(embd, "embd");
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
+#ifdef GGML_USE_METAL
+    if (lctx.ctx_metal && N == 1) {
+        ggml_metal_set_tensor(lctx.ctx_metal, embd);
+    }
+#endif
+
+    struct ggml_tensor * cur;
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
 
-        struct ggml_tensor * cur;
-
         lctx.use_buf(ctx0, 0);
 
         // norm
@@ -1271,6 +1286,7 @@ static bool llama_eval_internal(
         // self-attention
         {
             // compute Q and K and RoPE them
+
             struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
@@ -1280,6 +1296,7 @@ static bool llama_eval_internal(
             {
                 // compute the transposed [N, n_embd] V matrix
                 struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+                ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
@@ -1325,7 +1342,6 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
-
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
@@ -1407,26 +1423,53 @@ static bool llama_eval_internal(
 
     // norm
     {
+        cur = ggml_rms_norm(ctx0, inpL);
 
-        inpL = ggml_rms_norm(ctx0, inpL);
+        // cur = cur*norm(broadcasted)
+        cur = ggml_mul(ctx0, cur, model.norm);
 
-        // inpL = inpL*norm(broadcasted)
-        inpL = ggml_mul(ctx0, inpL, model.norm);
-
-        embeddings = inpL;
+        embeddings = cur;
     }
 
     // lm_head
-    inpL = ggml_mul_mat(ctx0, model.output, inpL);
+    cur = ggml_mul_mat(ctx0, model.output, cur);
 
     lctx.use_buf(ctx0, -1);
 
     // logits -> probs
-    //inpL = ggml_soft_max_inplace(ctx0, inpL);
+    //cur = ggml_soft_max_inplace(ctx0, cur);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute       (ctx0, &gf);
+    ggml_build_forward_expand(&gf, cur);
+
+#ifdef GGML_USE_METAL
+    if (lctx.ctx_metal && N == 1) {
+        ggml_metal_graph_compute(lctx.ctx_metal, &gf);
+        ggml_metal_get_tensor   (lctx.ctx_metal, cur);
+    } else {
+        // IMPORTANT:
+        // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
+        // ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
+        // coprocessor.
+        //
+        // When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
+        // But for now, we have focused only on Matrix x Vector Metal multiplication.
+        //
+        ggml_graph_compute(ctx0, &gf);
+
+        if (lctx.ctx_metal) {
+            // We need to sync the CPU KV cache with the GPU KV cache
+            ggml_metal_set_tensor(lctx.ctx_metal, kv_self.k);
+            ggml_metal_set_tensor(lctx.ctx_metal, kv_self.v);
+        }
+    }
+#else
+    ggml_graph_compute(ctx0, &gf);
+#endif
+
+    if (cgraph_fname) {
+        ggml_graph_export(&gf, cgraph_fname);
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
@@ -1440,7 +1483,7 @@ static bool llama_eval_internal(
     //}
 
     //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N);
 
     // update kv token count
     lctx.model.kv_self.n = n_past + N;
@@ -1451,11 +1494,11 @@ static bool llama_eval_internal(
 
         if (lctx.logits_all) {
             logits_out.resize(n_vocab * N);
-            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
         }
     }
 
@@ -2251,8 +2294,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
-                          params.progress_callback, params.progress_callback_user_data)) {
+                params.use_mmap, params.use_mlock, params.vocab_only,
+                params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
@@ -2290,6 +2333,25 @@ struct llama_context * llama_init_from_file(
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
     }
 
+#ifdef GGML_USE_METAL
+    if (params.n_gpu_layers > 0) {
+        // this allocates all Metal resources and memory buffers
+        ctx->ctx_metal = ggml_metal_init();
+
+        if (params.use_mmap) {
+            ggml_metal_add_buffer(ctx->ctx_metal, "data", ctx->model.mapping->addr, ctx->model.mapping->size);
+            ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr,    ctx->buf_compute.size);
+        } else {
+            ggml_metal_add_buffer(ctx->ctx_metal, "data", ggml_get_mem_buffer(ctx->model.ctx), ggml_get_mem_size(ctx->model.ctx));
+            ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr,               ctx->buf_compute.size);
+        }
+
+        ggml_metal_add_buffer(ctx->ctx_metal, "kv",   ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size);
+        ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr,    ctx->buf_scratch[0].size);
+        ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr,    ctx->buf_scratch[1].size);
+    }
+#endif
+
     return ctx;
 }
 
@@ -2905,7 +2967,7 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2920,6 +2982,20 @@ int llama_eval(
     return 0;
 }
 
+int llama_eval_export(struct llama_context * ctx, const char * fname) {
+    const int n_batch = 1;
+    const int n_ctx   = 512 - n_batch;
+
+    const std::vector<llama_token> tmp(n_batch, llama_token_bos());
+
+    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
 int llama_tokenize(
         struct llama_context * ctx,
                   const char * text,
diff --git a/llama.h b/llama.h
index c6b0a2889f8de1d70d7fd1eb96e19e2659a22c32..87fa9736784c8055a8bb39fd2295b6c8af174951 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -31,7 +31,7 @@
 #define LLAMA_SESSION_MAGIC          LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION        1
 
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif
@@ -173,6 +173,12 @@ extern "C" {
                              int   n_past,
                              int   n_threads);
 
+    // Export a static computation graph for context of 511 and batch size of 1
+    // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
+    //       parameters here to keep things simple
+    // IMPORTANT: do not use for anything else other than debugging and testing!
+    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens