]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-backend update: buffer types, backend registry, graph compare, tests (#620)
authorslaren <redacted>
Thu, 30 Nov 2023 18:03:03 +0000 (19:03 +0100)
committerGitHub <redacted>
Thu, 30 Nov 2023 18:03:03 +0000 (19:03 +0100)
* ggml-backend update

* update metal backend

* show metal logs with ggml-backend

* move buffer types to functions

* cuda: add per-device backends

* cuda: add host buffer type

* fix metal build

* ggml_backend_alloc_ctx_tensors : ignore allocated tensors

* ggml_backend_compare_graph_backend fixes

* ci : try to fix metal build

* metal : first print device info, then build kernels

* ci : disable GGML_METAL on Github Actions

* test-backend-ops initial impl (unary and get_rows)

* more op tests

* cleanup

* print test params, add more tests cases for add and mul

* add tests for im2col

* better f16 init

* metal : add basic impl of supports_op

* add test for ggml_concat

* update im2col test params, show callstack with GGML_ASSERT on CUDA failures

* add more rope tests

* add more rope and mul_mat test cases

* add more get_rows test cases
ggml-ci

* add more norm and rms_norm test cases with different eps

* ci : fix metal resource path

ggml-ci

* tests : silence warning

* add ggml_backend_tensor_alloc and ggml_backend_view_init for initializing tensors without ggml-alloc

* add mul_mat test cases without dims 3 and 4
ggml-ci

* check for nans and infs
ggml-ci

* add diag_mask_inf test cases without dims 3 and 4
ggml-ci

* fix cuda leak while backend reg

* fix msvc issues

* remove backend_sched debug causes by default

* gpt-2 : increase graph size

ggml-ci

---------

Co-authored-by: Georgi Gerganov <redacted>
24 files changed:
.github/workflows/ci.yml
ci/run.sh
examples/gpt-2/main-backend.cpp
examples/gpt-2/main-batched.cpp
examples/gpt-2/main.cpp
examples/whisper/whisper.cpp
include/ggml/ggml-alloc.h
include/ggml/ggml-backend.h
include/ggml/ggml.h
src/ggml-alloc.c
src/ggml-backend-impl.h
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-impl.h
src/ggml-metal.h
src/ggml-metal.m
src/ggml.c
tests/CMakeLists.txt
tests/test-backend-buffer.cpp [new file with mode: 0644]
tests/test-backend-ops.cpp [new file with mode: 0644]
tests/test-conv1d.cpp
tests/test-conv2d.cpp
tests/test-mul-mat.cpp

index 4e5c633af55332fad0135edf5d86243d1cc66e00..e5543d92139fe2914dd2513c2b4c000342a03922 100644 (file)
@@ -20,7 +20,7 @@ jobs:
     - name: Dependencies
       run: |
         wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | sudo tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null
-        echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list  
+        echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
         sudo apt-get update
         sudo apt-get install -y --no-install-recommends llvm intel-oneapi-runtime-opencl intel-oneapi-runtime-compilers libclblast-dev
     - name: Create Build Environment
@@ -59,7 +59,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=ON ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON ..
 
     - name: Build
       working-directory: ./build
index a593090f37b80e6ddeac6739b76c276c61829d0c..61fbef19e8b788578045a3436eadef85e23a5c58 100644 (file)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -94,6 +94,10 @@ function gg_run_ctest_debug {
     (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} ..     ) 2>&1 | tee -a $OUT/${ci}-cmake.log
     (time make -j                                              ) 2>&1 | tee -a $OUT/${ci}-make.log
 
+    if [ ! -z ${GG_BUILD_METAL} ]; then
+        export GGML_METAL_PATH_RESOURCES="$(pwd)/bin"
+    fi
+
     (time ctest --output-on-failure -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
 
     set +e
@@ -122,6 +126,10 @@ function gg_run_ctest_release {
     (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} ..   ) 2>&1 | tee -a $OUT/${ci}-cmake.log
     (time make -j                                              ) 2>&1 | tee -a $OUT/${ci}-make.log
 
+    if [ ! -z ${GG_BUILD_METAL} ]; then
+        export GGML_METAL_PATH_RESOURCES="$(pwd)/bin"
+    fi
+
     if [ -z $GG_BUILD_LOW_PERF ]; then
         (time ctest --output-on-failure ) 2>&1 | tee -a $OUT/${ci}-ctest.log
     else
index e986cdee1112a2b3d80e9bde20c7a2088f76141d..9ef873ece3d12ab21921f0037b27fa0393f4fba8 100644 (file)
@@ -26,6 +26,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#define GPT2_MAX_NODES 4096
+
 static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
@@ -177,47 +179,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
     auto & ctx = model.ctx;
 
-    size_t buffer_size = 0;
-
-    {
-        const auto & hparams = model.hparams;
-
-        const int n_embd  = hparams.n_embd;
-        const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
-        const int n_vocab = hparams.n_vocab;
-
-        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
-        buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
-
-        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
-        buffer_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
-        buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
-
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
-
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
-        buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
-
-        buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
-        buffer_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
-
-        buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
-        buffer_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
-
-        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
-        buffer_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
-
-        buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
-        buffer_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
-
-        buffer_size += (6 + 12*n_layer)*128; // alignment overhead
-
-        printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
-        printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
-    }
-
     // create the ggml context
     {
         size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;
@@ -227,8 +188,8 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             /*.no_alloc   =*/ true,
         };
 
-        model.ctx = ggml_init(params);
-        if (!model.ctx) {
+        ctx = ggml_init(params);
+        if (!ctx) {
             fprintf(stderr, "%s: ggml_init() failed\n", __func__);
             return false;
         }
@@ -238,7 +199,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > 0) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
+        model.backend = ggml_backend_cuda_init(0);
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
@@ -267,10 +228,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         return false;
     }
 
-    // allocate weights buffer
-    model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
-
-    // prepare memory for the weights
+    // create the tensors for the model
     {
         const auto & hparams = model.hparams;
 
@@ -338,6 +296,12 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         }
     }
 
+    // allocate the model tensors in a backend buffer
+    model.buffer_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %6.2f MB\n", __func__, ggml_backend_buffer_get_size(model.buffer_w)/(1024.0*1024.0));
+
     // override the default training context with the user-provided
     model.hparams.n_ctx = n_ctx;
 
@@ -378,8 +342,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
     // load weights
     {
-        ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);
-
         size_t total_size = 0;
 
         bool has_lm_head = false;
@@ -440,8 +402,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
                 return false;
             }
 
-            ggml_allocr_alloc(alloc, tensor);
-
             if (ggml_backend_is_cpu  (model.backend)
 #ifdef GGML_USE_METAL
                 || ggml_backend_is_metal(model.backend)
@@ -470,7 +430,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             total_size += ggml_nbytes(tensor);
         }
 
-        ggml_allocr_free(alloc);
         printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
     }
 
@@ -495,7 +454,7 @@ struct ggml_cgraph * gpt2_graph(
     const int n_head  = hparams.n_head;
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
@@ -506,7 +465,7 @@ struct ggml_cgraph * gpt2_graph(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx0, GPT2_MAX_NODES, false);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     ggml_allocr_alloc(allocr, embd);
@@ -801,16 +760,56 @@ bool gpt2_eval(
     // allocate tensors
     ggml_allocr_alloc_graph(allocr, gf);
 
-    // run the computation
+    // set backend options
     if (ggml_backend_is_cpu(model.backend)) {
         ggml_backend_cpu_set_n_threads(model.backend, n_threads);
     }
+
 #ifdef GGML_USE_METAL
     if (ggml_backend_is_metal(model.backend)) {
         ggml_backend_metal_set_n_cb(model.backend, n_threads);
     }
 #endif
-    ggml_backend_graph_compute(model.backend, gf);
+
+    // test
+#if 0 && defined(GGML_USE_CUBLAS)
+    if (ggml_backend_is_cuda(model.backend)) {
+        auto eval_callback = [](int index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
+            auto tv1 = tensor_to_float(t1);
+            auto tv2 = tensor_to_float(t2);
+
+#if 1
+            float sim = cosine_similarity(tv1, tv2);
+            float len1 = vec_len(tv1);
+            float len2 = vec_len(tv2);
+            float lenr = len1/len2;
+            float lenrd = std::abs(1.0f-lenr);
+
+            float angle = acosf(sim)*180.0f/M_PI;
+
+            if (angle > 0.5f || lenrd > 0.05f) {
+                printf("%3d [%15s] %s: sim = %f, a = %f, lenrd = %f\n", index, ggml_op_desc(t1), t1->name, sim, angle, lenrd);
+            }
+            assert(sim > 0.90f);
+#else
+            float dist = distance(tv1, tv2) / vec_len(tv1);
+            if (dist > 0.01f) {
+                printf("%3d [%15s] %s: distance = %f\n", index, ggml_op_desc(t1), t1->name, dist);
+            }
+#endif
+
+            return true;
+        };
+        ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+        ggml_backend_compare_graph_backend(model.backend, backend_cpu, gf, eval_callback, nullptr);
+        ggml_backend_free(backend_cpu);
+        //printf("done\n");
+    } else
+#endif
+    {
+        // run the computation
+        ggml_backend_graph_compute(model.backend, gf);
+    }
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);
index 41d531aa0a2261e07446bbccf4afe18b0ff8f9a8..8c998f0ba7ae1eb0a96f74b2cad84a80849fa0ac 100644 (file)
@@ -27,6 +27,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#define GPT2_MAX_NODES 4096
+
 static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
@@ -286,7 +288,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > 0) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
+        model.backend = ggml_backend_cuda_init(0);
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
@@ -551,7 +553,7 @@ struct ggml_cgraph * gpt2_graph(
     const int32_t kv_head  = ggml_allocr_is_measure(allocr) ? n_ctx - n_tokens : kv_cache.head;
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
@@ -562,7 +564,7 @@ struct ggml_cgraph * gpt2_graph(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx0, GPT2_MAX_NODES, false);
 
     struct ggml_tensor * inpL;
     if (batch.token) {
index 1a249254d60a8a02eb576d8bbf9ed355f25a3f5d..08a9aa3f873ed98eb858616ae618908efecc505e 100644 (file)
@@ -26,6 +26,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#define GPT2_MAX_NODES 4096
+
 static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
@@ -107,7 +109,7 @@ void init_backends(gpt2_model & model, const gpt_params & params) {
 #ifdef GGML_USE_CUBLAS
     if (params.n_gpu_layers > 0) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        gpu_backend = ggml_backend_cuda_init();
+        gpu_backend = ggml_backend_cuda_init(0);
         if (!gpu_backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
@@ -553,7 +555,7 @@ struct ggml_cgraph * gpt2_graph(
     const int n_head  = hparams.n_head;
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead()*GPT2_MAX_NODES + ggml_graph_overhead_custom(GPT2_MAX_NODES, false);
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params params = {
@@ -564,7 +566,7 @@ struct ggml_cgraph * gpt2_graph(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
+    struct ggml_cgraph  * gf = ggml_new_graph_custom(ctx0, GPT2_MAX_NODES, false);
 
     struct ggml_tensor * embd = ggml_view_1d(ctx0, model.embd, N, 0);
 
index e2197ff262c993ae7ac14ed393e1d116242fe45f..f0a0a5a6d37916d8015c040cfd3f6db8ee9239c7 100644 (file)
@@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 #ifdef GGML_USE_CUBLAS
     if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init();
+        backend_gpu = ggml_backend_cuda_init(0);
         if (!backend_gpu) {
             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
         }
index dde2a06bf803098a85d60c15a47953753e4a0bd8..ad87cebc8873f4584ad99e2c67e110a603f4a71c 100644 (file)
@@ -8,6 +8,7 @@ extern "C" {
 
 struct ggml_backend;
 struct ggml_backend_buffer;
+struct ggml_backend_buffer_type;
 
 //
 // Legacy API
@@ -80,6 +81,12 @@ GGML_API void   ggml_gallocr_alloc_graph_n(
                     struct ggml_hash_set hash_set,
                     ggml_tallocr_t * hash_node_talloc);
 
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+
 #ifdef  __cplusplus
 }
 #endif
index 966687320ac96d971e9d635429d4ea1018a7ce57..58d5ccae6ed101ca80df9390386adc5329536722 100644 (file)
@@ -7,41 +7,44 @@
 extern "C" {
 #endif
 
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
     //
     // Backend buffer
     //
 
-    struct ggml_backend_buffer;
-    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    // buffer type
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
 
-    // backend buffer functions
+    // buffer
     GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
     GGML_API size_t ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API void   ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API void   ggml_backend_buffer_free_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
 
     //
     // Backend
     //
 
-    struct ggml_backend;
-    typedef struct ggml_backend * ggml_backend_t;
-    typedef void * ggml_backend_graph_plan_t;
-
-    GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
 
     GGML_API const char * ggml_backend_name(ggml_backend_t backend);
     GGML_API void         ggml_backend_free(ggml_backend_t backend);
 
-    GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
-
-    GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
 
-    GGML_API void ggml_backend_tensor_set_async(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
     GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
@@ -57,6 +60,7 @@ extern "C" {
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
 
     //
     // CPU backend
@@ -68,8 +72,23 @@ extern "C" {
     GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
 
     // Create a backend buffer from an existing pointer
-    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
+    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
 
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
 
     //
     // Backend scheduler
@@ -131,6 +150,32 @@ extern "C" {
             ggml_backend_sched_t sched,
             struct ggml_cgraph * graph);
 
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+
+
 #ifdef  __cplusplus
 }
 #endif
index 8e6b646066b7a488197becae814d17e504916194..b53abaa16a025c5d210db40e9f8b67e0809cb1e1 100644 (file)
@@ -449,7 +449,9 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY
+        GGML_UNARY_OP_LEAKY,
+
+        GGML_UNARY_OP_COUNT,
     };
 
     enum ggml_object_type {
@@ -632,6 +634,9 @@ extern "C" {
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
+    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
@@ -1749,7 +1754,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
     GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API struct ggml_cgraph * ggml_graph_view        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
     GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
index cdfe4caf69613d6778f8a80af9b0d6ffc8e67652..80f2a4975f457f75badeb37250c83233eb9e02c9 100644 (file)
@@ -135,6 +135,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
         ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
     }
 
+
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
     size_t cur_max = (char*)addr - (char*)alloc->data + size;
@@ -168,10 +169,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
-    if (!alloc->measure) {
-        ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
-    }
-
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
 #endif
@@ -237,7 +234,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
 }
 
 ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
-    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
+    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
 
     ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
@@ -449,7 +446,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
 static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
     ggml_tallocr_t alloc = node_tallocr(galloc, view);
 
-    //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
     GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
     if (update_backend) {
         view->backend = view->view_src->backend;
@@ -459,7 +455,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
 
     // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
     // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
-    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
 
     if (!alloc->measure) {
         ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +761,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
 size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
     return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
 }
+
+// utils
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+    size_t nbytes = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL && t->view_src == NULL) {
+            nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+    }
+
+    if (nbytes == 0) {
+        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+    ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(tallocr, t);
+            } else {
+                ggml_backend_view_init(buffer, t);
+            }
+        }
+    }
+
+    ggml_tallocr_free(tallocr);
+
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
index 211e3d4247387b2b598caca2e79d6863ffdf4c35..d623b885197bb692fe8cfd916d04f849d488d88d 100644 (file)
@@ -12,31 +12,50 @@ extern "C" {
     // Backend buffer
     //
 
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        ggml_backend_buffer_t (*alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        size_t                (*get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
+        size_t                (*get_alloc_size)  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+        bool                  (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
-        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
-        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
-        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
-        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+        void     (*free_buffer)(ggml_backend_buffer_t buffer);
+        //void     (*reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        void *   (*get_base)   (ggml_backend_buffer_t buffer);
+        void     (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void     (*set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void     (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
     };
 
     struct ggml_backend_buffer {
-        struct ggml_backend_buffer_i iface;
-
-        ggml_backend_t                backend;
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
         ggml_backend_buffer_context_t context;
-
         size_t size;
     };
 
-    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
-            struct ggml_backend                  * backend,
+    ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t      buft,
             struct ggml_backend_buffer_i           iface,
                    ggml_backend_buffer_context_t   context,
                    size_t                          size);
 
+
     //
     // Backend
     //
@@ -49,20 +68,17 @@ extern "C" {
         void (*free)(ggml_backend_t backend);
 
         // buffer allocation
-        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
-
-        // get buffer alignment
-        size_t (*get_alignment)(ggml_backend_t backend);
+        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
-        // tensor data access
-        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        // (optional) asynchroneous tensor data access
         void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        void (*synchronize)     (ggml_backend_t backend);
 
-        // (optional) copy tensor between different backends, allow for single-copy tranfers
-        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        // (optional) asynchroneous tensor copy
+        void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to_async)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        void (*synchronize)     (ggml_backend_t backend);
 
         // compute graph with a plan
         ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,44 @@ extern "C" {
         ggml_backend_context_t context;
     };
 
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+    size_t ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
+
+    // Register a int function to be called at program startup
+    #if defined(__GNUC__) || defined(__clang__)
+        #define GGML_CONSTRUCTOR(init_fn) \
+            static void __attribute__((constructor)) init_fn ## _ggml_constructor(void) { \
+                init_fn(); \
+            }
+    #elif defined(_MSC_VER)
+        #ifdef __cplusplus
+            #define GGML_CONSTRUCTOR(init_fn) \
+                static int init_fn ## _ggml_constructor_dummy = init_fn();
+        #else
+            #define GGML_CONSTRUCTOR(init_fn) \
+                __pragma(section(".CRT$XCV", read)) \
+                __declspec(allocate(".CRT$XCV")) int (*init_fn ## _ggml_constructor)(void) = init_fn; \
+                __pragma(comment(linker, "/include:" #init_fn "_ggml_constructor"))
+        #endif
+    #else
+        #error "GGML_CONSTRUCTOR not implemented for this compiler"
+    #endif
+
+
+    // Register a backend
+    #define GGML_BACKEND_REGISTER(name, init_fn, buft, user_data) \
+        static void init_fn ## _backend_register(void) { \
+            ggml_backend_register(name, init_fn, buft, user_data); \
+        } \
+        GGML_CONSTRUCTOR(init_fn ## _backend_register)
+
 #ifdef  __cplusplus
 }
 #endif
index f6e5fceed0f4df2acf14a44953eb0f51df4443b3..e721f9c8979a6056d5411713d4c031859805ca55 100644 (file)
@@ -9,14 +9,36 @@
 #include <stdlib.h>
 #include <string.h>
 
-#define UNUSED GGML_UNUSED
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+
+// backend buffer type
+
+ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        return buft->iface.get_alloc_size(buft, tensor);
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return buft->iface.supports_backend(buft, backend);
+}
+
 // backend buffer
 
 ggml_backend_buffer_t ggml_backend_buffer_init(
-        struct ggml_backend                  * backend,
+               ggml_backend_buffer_type_t      buft,
         struct ggml_backend_buffer_i           iface,
                ggml_backend_buffer_context_t   context,
                size_t                          size) {
@@ -26,7 +48,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
 
     (*buffer) = (struct ggml_backend_buffer) {
         /* .interface = */ iface,
-        /* .backend   = */ backend,
+        /* .buft      = */ buft,
         /* .context   = */ context,
         /* .size      = */ size,
     };
@@ -45,10 +67,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
     free(buffer);
 }
 
-size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
-    return ggml_backend_get_alignment(buffer->backend);
-}
-
 size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
     return buffer->size;
 }
@@ -61,14 +79,6 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
     return base;
 }
 
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // get_alloc_size is optional, defaults to ggml_nbytes
-    if (buffer->iface.get_alloc_size) {
-        return buffer->iface.get_alloc_size(buffer, tensor);
-    }
-    return ggml_nbytes(tensor);
-}
-
 void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     // init_tensor is optional
     if (buffer->iface.init_tensor) {
@@ -76,19 +86,20 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
     }
 }
 
-void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // free_tensor is optional
-    if (buffer->iface.free_tensor) {
-        buffer->iface.free_tensor(buffer, tensor);
-    }
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
 }
 
-// backend
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
+}
 
-ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
-    return tensor->buffer ? tensor->buffer->backend : NULL;
+ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
 }
 
+// backend
+
 const char * ggml_backend_name(ggml_backend_t backend) {
     if (backend == NULL) {
         return "NULL";
@@ -104,43 +115,53 @@ void ggml_backend_free(ggml_backend_t backend) {
     backend->iface.free(backend);
 }
 
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return backend->iface.get_default_buffer_type(backend);
+}
+
 ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
-    return backend->iface.alloc_buffer(backend, size);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
 }
 
 size_t ggml_backend_get_alignment(ggml_backend_t backend) {
-    return backend->iface.get_alignment(backend);
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
 }
 
-void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
 }
 
-void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
-    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
-    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
     backend->iface.synchronize(backend);
 }
 
@@ -154,10 +175,16 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
 
 void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     backend->iface.graph_plan_compute(backend, plan);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     backend->iface.graph_compute(backend, cgraph);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -194,14 +221,15 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
 
     // TODO: allow backends to support copy to/from same backend
 
-    if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
-        ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
-    } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
-        ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
+    if (dst->buffer->iface.cpy_tensor_from != NULL) {
+        dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
+    } else if (src->buffer->iface.cpy_tensor_to != NULL) {
+        src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
     } else {
         // shouldn't be hit when copying from/to CPU
         #ifndef NDEBUG
-        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
+        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
+                        "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
         #endif
         size_t nbytes = ggml_nbytes(src);
         void * data = malloc(nbytes);
@@ -211,101 +239,218 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
     }
 }
 
-// backend CPU
+// backend registry
 
-struct ggml_backend_cpu_context {
-    int n_threads;
-    void * work_data;
-    size_t work_size;
+struct ggml_backend_reg {
+    char name[128];
+    ggml_backend_init_fn init_fn;
+    ggml_backend_buffer_type_t default_buffer_type;
+    void * user_data;
 };
 
-static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
-    return "CPU";
+#define GGML_MAX_BACKENDS_REG 16
+static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
+static size_t ggml_backend_registry_count = 0;
+
+size_t ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+    GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
+
+    int id = ggml_backend_registry_count;
 
-    UNUSED(backend);
+    ggml_backend_registry[id] = (struct ggml_backend_reg) {
+        /* .name                = */ {0},
+        /* .fn                  = */ init_fn,
+        /* .default_buffer_type = */ default_buffer_type,
+        /* .user_data           = */ user_data,
+    };
+
+    snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+    fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+
+    ggml_backend_registry_count++;
+    return ggml_backend_registry_count - 1;
 }
 
-static void ggml_backend_cpu_free(ggml_backend_t backend) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-    free(cpu_ctx->work_data);
-    free(cpu_ctx);
-    free(backend);
+
+size_t ggml_backend_reg_get_count(void) {
+    return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+        // TODO: case insensitive in a portable way
+        if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+            return i;
+        }
+    }
+    return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+    const char * params = strchr(backend_str, ':');
+    char backend_name[128];
+    if (params == NULL) {
+        strcpy(backend_name, backend_str);
+        params = "";
+    } else {
+        strncpy(backend_name, backend_str, params - backend_str);
+        backend_name[params - backend_str] = '\0';
+        params++;
+    }
+
+    size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+    if (backend_i == SIZE_MAX) {
+        fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+        return NULL;
+    }
+
+    return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
 }
 
+// backend CPU
+
 static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
     return (void *)buffer->context;
 }
 
 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     free(buffer->context);
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_cpu_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 // for buffers from ptr, free is not called
 static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
-    /* .free_buffer    = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL,
-    /* .free_tensor    = */ NULL,
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
 
-static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
+static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
     void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
 
     GGML_ASSERT(data != NULL && "failed to allocate buffer");
 
-    return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
 }
 
-static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return TENSOR_ALIGNMENT;
-    UNUSED(backend);
-}
 
-static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_UNUSED(buft);
+}
 
-    memcpy((char *)tensor->data + offset, data, size);
+static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cpu(backend);
 
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_cpu;
 }
 
-static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
-}
+struct ggml_backend_cpu_context {
+    int n_threads;
+    void * work_data;
+    size_t work_size;
+};
 
-static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+    return "CPU";
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
-static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    free(cpu_ctx->work_data);
+    free(cpu_ctx);
+    free(backend);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 struct ggml_backend_plan_cpu {
@@ -334,7 +479,7 @@ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backen
     free(cpu_plan->cplan.work_data);
     free(cpu_plan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -342,7 +487,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
 
     ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -363,25 +508,25 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
 
 static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     return true;
-    UNUSED(backend);
-    UNUSED(op);
+
+    GGML_UNUSED(backend);
+    GGML_UNUSED(op);
 }
 
 static struct ggml_backend_i cpu_backend_i = {
-    /* .get_name            = */ ggml_backend_cpu_name,
-    /* .free                = */ ggml_backend_cpu_free,
-    /* .alloc_buffer        = */ ggml_backend_cpu_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_cpu_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_cpu_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_cpu_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_cpu_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_cpu_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_cpu_cpy_tensor_to,
-    /* .graph_plan_create   = */ ggml_backend_cpu_graph_plan_create,
-    /* .graph_plan_free     = */ ggml_backend_cpu_graph_plan_free,
-    /* .graph_plan_compute  = */ ggml_backend_cpu_graph_plan_compute,
-    /* .graph_compute       = */ ggml_backend_cpu_graph_compute,
-    /* .supports_op         = */ ggml_backend_cpu_supports_op,
+    /* .get_name                = */ ggml_backend_cpu_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .supports_op             = */ ggml_backend_cpu_supports_op,
 };
 
 ggml_backend_t ggml_backend_cpu_init(void) {
@@ -411,10 +556,19 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
-ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
-    return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
 }
 
+GGML_BACKEND_REGISTER("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL)
+
 // scheduler
 
 #define GGML_MAX_BACKENDS 4
@@ -427,7 +581,7 @@ struct ggml_backend_sched_split {
     int i_end;
     struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
     int n_inputs;
-    struct ggml_cgraph graph;
+    struct ggml_cgraph graph;
 };
 
 struct ggml_backend_sched {
@@ -453,7 +607,7 @@ struct ggml_backend_sched {
     #else
     __attribute__((aligned(GGML_MEM_ALIGN)))
     #endif
-    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
+    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -482,23 +636,57 @@ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr)
     return INT_MAX;
 }
 
+static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+            return sched->backends[i];
+        }
+    }
+    GGML_ASSERT(false && "tensor buffer type not supported by any backend");
+}
+
+static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    if (allocr == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return sched->backends[i];
+        }
+    }
+    GGML_UNREACHABLE();
+}
+
+#if 0
+static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
 // returns the backend that should be used for the node based on the current locations
-char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
 static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
     // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
     // ie. kv cache updates
     // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
     // dst
-    ggml_backend_t cur_backend = ggml_get_backend(node);
+    ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
     if (cur_backend != NULL) {
-        sprintf(causes[hash_id(node)], "1.dst");
+        SET_CAUSE(node, "1.dst");
         return cur_backend;
     }
 
     // view_src
-    if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
-        sprintf(causes[hash_id(node)], "1.vsrc");
-        return ggml_get_backend(node->view_src);
+    if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
+        SET_CAUSE(node, "1.vsrc");
+        return get_buffer_backend(sched, node->view_src->buffer);
     }
 
     // src
@@ -510,7 +698,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
         if (src == NULL) {
             break;
         }
-        ggml_backend_t src_backend = ggml_get_backend(src);
+        ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
         if (src_backend != NULL) {
             int src_prio = sched_backend_prio(sched, src_backend);
             size_t src_size = ggml_nbytes(src);
@@ -518,7 +706,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
                 cur_prio = src_prio;
                 cur_size = src_size;
                 cur_backend = src_backend;
-                sprintf(causes[hash_id(node)], "1.src%d", i);
+                SET_CAUSE(node, "1.src%d", i);
             }
         }
     }
@@ -539,10 +727,12 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
     int cur_split = 0;
     for (int i = 0; i < graph->n_nodes; i++) {
         if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
-            ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
-            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+            ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
             for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
-                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
             }
             fprintf(stderr, "\n");
             cur_split++;
@@ -552,16 +742,18 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
             continue;
         }
         ggml_tallocr_t node_allocr = node_allocr(node);
-        ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
-        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+        ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
+            fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
                 break;
             }
             ggml_tallocr_t src_allocr = node_allocr(src);
-            ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
-            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+            ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
+                fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
         }
         fprintf(stderr, "\n");
     }
@@ -587,9 +779,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     sched->n_splits = 0;
 
     struct ggml_init_params params = {
-        /*.mem_size =   */ sizeof(sched->context_buffer),
-        /*.mem_buffer = */ sched->context_buffer,
-        /*.no_alloc =   */ true
+        /* .mem_size =   */ sizeof(sched->context_buffer),
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
     };
 
     if (sched->ctx != NULL) {
@@ -605,9 +797,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             // do not overwrite user assignments
             continue;
         }
-        ggml_backend_t leaf_backend = ggml_get_backend(leaf);
+        ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
         if (leaf_backend == NULL && leaf->view_src != NULL) {
-            leaf_backend = ggml_get_backend(leaf->view_src);
+            leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
         }
         if (leaf_backend != NULL) {
             node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
@@ -649,7 +841,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                         cur_prio = src_prio;
                         cur_size = src_size;
                         node_allocr = src_allocr;
-                        sprintf(causes[hash_id(node)], "2.src%d", j);
+                        SET_CAUSE(node, "2.src%d", j);
                     }
                 }
             }
@@ -733,7 +925,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                     struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
                     sched->node_copies[id][cur_backend_id] = tensor_copy;
                     node_allocr(tensor_copy) = cur_allocr;
-                    ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
+                    ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
                     ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
                 }
                 node->src[j] = sched->node_copies[id][cur_backend_id];
@@ -761,8 +953,8 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             ggml_tallocr_t src_allocr = node_allocr(src);
             if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
                 fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
-                    node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
-                    j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+                    node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+                    j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
             }
         }
     }
@@ -773,7 +965,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
-        split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
 
         // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
         for (int j = 0; j < split->n_inputs; j++) {
@@ -806,31 +998,29 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &splits[i];
-        ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
+        ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
         int split_backend_id = sched_backend_prio(sched, split_backend);
 
         // copy the input tensors to the split backend
         uint64_t copy_start_us = ggml_time_us();
         for (int j = 0; j < split->n_inputs; j++) {
-            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
-            if (split->inputs[j]->buffer == NULL) {
-                if (split->inputs[j]->view_src == NULL) {
-                    fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
+            if (input->buffer == NULL) {
+                if (input->view_src == NULL) {
+                    fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
                     exit(1);
                 }
-                struct ggml_tensor * view = split->inputs[j];
-                view->backend = view->view_src->backend;
-                view->buffer  = view->view_src->buffer;
-                view->data    = (char *)view->view_src->data + view->view_offs;
-                ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+                // FIXME: may need to use the sched buffer instead
+                ggml_backend_view_init(input->view_src->buffer, input);
             }
             if (input_cpy->buffer == NULL) {
                 fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
                 exit(1);
             }
-            GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
-            GGML_ASSERT(input_cpy->buffer->backend == split_backend);
-            ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+            //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
+            //GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+            ggml_backend_tensor_copy(input, input_cpy);
         }
         // ggml_backend_synchronize(split_backend);
         int64_t copy_end_us = ggml_time_us();
@@ -843,7 +1033,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 #endif
 
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, split->graph);
+        ggml_backend_graph_compute(split_backend, &split->graph);
         // ggml_backend_synchronize(split_backend);
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
@@ -948,3 +1138,181 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     node_allocr(node) = sched->tallocs[backend_index];
 }
+
+// utils
+void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    tensor->backend = tensor->view_src->backend;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(hash_set, src);
+    if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        ggml_backend_view_init(dst->view_src->buffer, dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        graph_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = {
+        /* .size = */ graph->visited_hash_table.size,
+        /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+    };
+    struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
+    bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_init_tensor(hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    free(hash_set.keys);
+    free(node_copies);
+
+    return (struct ggml_backend_graph_copy) {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
+void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+}
index c0c9edd56dbc232b060afd62304d04bf4df45be3..6b194c2ad2c858fbb512e2fab44db378221c563a 100644 (file)
@@ -1,6 +1,7 @@
 #include <algorithm>
 #include <cstddef>
 #include <cstdint>
+#include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
@@ -189,7 +190,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"CUDA error");                                                 \
         }                                                                               \
     } while (0)
 
@@ -203,7 +204,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"cuBLAS error");                                               \
         }                                                                               \
     } while (0)
 #else
@@ -215,7 +216,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"cuBLAS error");                                               \
         }                                                                               \
     } while (0)
 #endif // CUDART_VERSION >= 11
@@ -4684,8 +4685,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     }
 
     const int i = row*ncols + col;
-    // dst[i] = col > n_past + row ? -INFINITY : x[i];
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
 // the CUDA soft max implementation differs from the CPU implementation
@@ -7988,8 +7990,9 @@ void ggml_cuda_set_main_device(const int main_device) {
                 main_device, g_device_count, g_main_device);
         return;
     }
-    g_main_device = main_device;
-    if (g_device_count > 1) {
+
+    if (g_main_device != main_device && g_device_count > 1) {
+        g_main_device = main_device;
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
         fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
@@ -8148,27 +8151,16 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
 
 #define UNUSED GGML_UNUSED
 
-struct ggml_backend_context_cuda {
-};
-
-static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
-    return GGML_CUDA_NAME;
-
-    UNUSED(backend);
-}
-
-static void ggml_backend_cuda_free(ggml_backend_t backend) {
-    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
-    delete cuda_ctx;
-    delete backend;
-}
+// cuda buffer
 
 struct ggml_backend_buffer_context_cuda {
-    void * device;
-
+    int device;
+    void * dev_ptr = nullptr;
     ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
     size_t temp_tensor_extra_index = 0;
 
+    ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
+
     ~ggml_backend_buffer_context_cuda() {
         delete[] temp_tensor_extras;
     }
@@ -8189,41 +8181,20 @@ struct ggml_backend_buffer_context_cuda {
 
 static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
-    CUDA_CHECK(cudaFree(ctx->device));
+    CUDA_CHECK(cudaFree(ctx->dev_ptr));
     delete ctx;
 }
 
 static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
-    return ctx->device;
-}
-
-static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    int64_t row_low = 0;
-    int64_t row_high = ggml_nrows(tensor);
-    int64_t nrows_split = row_high - row_low;
-
-    size_t size = ggml_nbytes_split(tensor, nrows_split);
-
-    int64_t ne0 = tensor->ne[0];
-
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
-                * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
-        }
-    }
-
-    return size;
-
-    UNUSED(buffer);
+    return ctx->dev_ptr;
 }
 
 static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
 
     if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        assert(tensor->view_src->buffer->backend == buffer->backend);
+        assert(tensor->view_src->buffer->buft == buffer->buft); // TODO
         tensor->backend = tensor->view_src->backend;
         tensor->extra = tensor->view_src->extra;
         return;
@@ -8231,7 +8202,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
 
     ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
 
-    extra->data_device[g_main_device] = tensor->data;
+    extra->data_device[ctx->device] = tensor->data;
 
     tensor->backend = GGML_BACKEND_GPU;
     tensor->extra = extra;
@@ -8243,64 +8214,207 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
         int64_t nrows_split = row_high - row_low;
 
         size_t original_size = ggml_nbytes_split(tensor, nrows_split);
-        size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
+        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
         if (padded_size > original_size && tensor->view_src == nullptr) {
-            CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
+            CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0]));
         }
     }
 
     UNUSED(buffer);
 }
 
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+    CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+    CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
+
+    UNUSED(buffer);
+}
+
 static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
-    /* .free_buffer    = */ ggml_backend_cuda_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_cuda_buffer_get_base,
-    /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
-    /* .init_tensor    = */ ggml_backend_cuda_buffer_init_tensor,
-    /* .free_tensor    = */ NULL,
+    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
+    /* .cpy_tensor_from = */ NULL,
+    /* .cpy_tensor_to   = */ NULL,
 };
 
-static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
-    ggml_cuda_set_device(g_main_device);
+// cuda buffer type
+
+static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    int device = (int) (intptr_t) buft->context;
 
-    ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+    ggml_cuda_set_device(device);
 
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
-    ggml_cuda_set_device(g_main_device);
-    CUDA_CHECK(cudaMalloc(&ctx->device, size));
+    void * dev_ptr;
+    CUDA_CHECK(cudaMalloc(&dev_ptr, size));
+
+    ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr);
 
-    return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
+    return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size);
 }
 
-static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 128;
+
+    UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) {
+    int64_t row_low = 0;
+    int64_t row_high = ggml_nrows(tensor);
+    int64_t nrows_split = row_high - row_low;
+
+    size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+    int64_t ne0 = tensor->ne[0];
+
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+                * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+        }
+    }
+
+    return size;
+
+    UNUSED(buft);
+}
+
+static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cuda(backend);
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = {
+    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
+    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+    /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES];
+    static bool ggml_backend_buffer_type_cuda_initialized = false;
+    if (!ggml_backend_buffer_type_cuda_initialized) {
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
+            ggml_backend_buffer_type_cuda[i] = {
+                /* .iface    = */ cuda_backend_buffer_type_interface,
+                /* .context  = */ (ggml_backend_buffer_type_context_t) (intptr_t) i,
+            };
+        }
+        ggml_backend_buffer_type_cuda_initialized = true;
+    }
+
+    return &ggml_backend_buffer_type_cuda[device];
+}
+
+// host buffer type
+
+static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+    CUDA_CHECK(cudaFreeHost(ctx->dev_ptr));
+    delete ctx;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost(&ptr, size));
+
+    // FIXME: this is a hack to avoid having to implement a new buffer type
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+    return buffer;
+
+    UNUSED(buft);
+}
+
+struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = {
+    /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+    /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+    /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = {
+        /* .iface    = */ cuda_backend_host_buffer_type_interface,
+        /* .context  = */ nullptr,
+    };
+
+    return &ggml_backend_buffer_type_cuda_host;
+}
+
+// backend
+
+struct ggml_backend_context_cuda {
+    int device;
+};
+
+static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+    return GGML_CUDA_NAME;
+
     UNUSED(backend);
 }
 
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    delete cuda_ctx;
+    delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    return ggml_backend_cuda_buffer_type(cuda_ctx->device);
+}
+
 static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 
-    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
-
-    UNUSED(backend);
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
 }
 
 static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 
-    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-
-    UNUSED(backend);
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
 }
 
 static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0]));
 
     UNUSED(backend);
 }
@@ -8329,7 +8443,9 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
 }
 
 static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_cuda_set_device(g_main_device);
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    ggml_cuda_set_main_device(cuda_ctx->device);
 
     ggml_compute_params params = {};
     params.type = GGML_TASK_COMPUTE;
@@ -8339,10 +8455,16 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
 
         if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
             continue;
+
         assert(node->backend == GGML_BACKEND_GPU);
+        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+        assert(node->extra != nullptr);
+
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
                 assert(node->src[j]->backend == GGML_BACKEND_GPU);
+                assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                assert(node->src[j]->extra != nullptr);
             }
         }
 
@@ -8379,27 +8501,76 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     UNUSED(backend);
 }
 
+static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) {
+    switch (tensor->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_NORM:
+        case GGML_OP_REPEAT:
+        case GGML_OP_GET_ROWS:
+        case GGML_OP_DUP:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_CLAMP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_ROPE:
+        case GGML_OP_ALIBI:
+        case GGML_OP_IM2COL:
+            return true;
+        default:
+            return false;
+    }
+
+    UNUSED(backend);
+}
+
 static ggml_backend_i cuda_backend_i = {
-    /* .get_name            = */ ggml_backend_cuda_name,
-    /* .free                = */ ggml_backend_cuda_free,
-    /* .alloc_buffer        = */ ggml_backend_cuda_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_cuda_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_cuda_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_cuda_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_cuda_synchronize,
-    /* .cpy_tensor_from     = */ nullptr,
-    /* .cpy_tensor_to       = */ nullptr,
-    /* .graph_plan_create   = */ ggml_backend_cuda_graph_plan_create,
-    /* .graph_plan_free     = */ ggml_backend_cuda_graph_plan_free,
-    /* .graph_plan_compute  = */ ggml_backend_cuda_graph_plan_compute,
-    /* .graph_compute       = */ ggml_backend_cuda_graph_compute,
-    /* .supports_op         = */ nullptr,
+    /* .get_name                = */ ggml_backend_cuda_name,
+    /* .free                    = */ ggml_backend_cuda_free,
+    /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
+    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ ggml_backend_cuda_synchronize,
+    /* .graph_plan_create       = */ ggml_backend_cuda_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cuda_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cuda_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
+    /* .supports_op             = */ ggml_backend_cuda_supports_op,
 };
 
-ggml_backend_t ggml_backend_cuda_init() {
+ggml_backend_t ggml_backend_cuda_init(int device) {
     ggml_init_cublas(); // TODO: remove from ggml.c
 
-    ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
+    if (device < 0 || device >= ggml_cuda_get_device_count()) {
+        fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
+        /* .device = */ device
+    };
 
     ggml_backend_t cuda_backend = new ggml_backend {
         /* .interface = */ cuda_backend_i,
@@ -8408,3 +8579,26 @@ ggml_backend_t ggml_backend_cuda_init() {
 
     return cuda_backend;
 }
+
+bool ggml_backend_is_cuda(ggml_backend_t backend) {
+    return backend->iface.get_name == ggml_backend_cuda_name;
+}
+
+static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
+    ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+    return cuda_backend;
+
+    UNUSED(params);
+}
+
+static int ggml_backend_cuda_reg_devices() {
+    int device_count = ggml_cuda_get_device_count();
+    for (int i = 0; i < device_count; i++) {
+        char name[128];
+        snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
+        ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
+    }
+    return device_count;
+}
+
+GGML_CONSTRUCTOR(ggml_backend_cuda_reg_devices)
index 528e66c33a20738ce185744ff5780203c869e4ad..cdb0c0c41618a0692bdad716af60b64cd0b70ad2 100644 (file)
@@ -49,7 +49,15 @@ GGML_API int    ggml_cuda_get_device_count(void);
 GGML_API void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
+GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_API int  ggml_backend_cuda_get_device(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
 #ifdef  __cplusplus
 }
index 06c07339e926999ed1688da0cf40d21ab08ae8ae..1f5610a86cfd9e8347952dbf415a176f1fd23baf 100644 (file)
@@ -232,7 +232,7 @@ bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml
 // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
 size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
-// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
 // return index, asserts if table is full
index be2731f8ba476728c50b4971a67a74a9437b981d..14ab5029cf8f7ad7354452b20103cfb3fe91a232 100644 (file)
@@ -98,6 +98,8 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
+GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+
 GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
 
 #ifdef __cplusplus
index 4fe9cc487d29f6eb2cac0c4fb0a172d6b6d28136..37b291a9e31756e22de65ae1df1a837d9824b00e 100644 (file)
@@ -169,7 +169,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
-    id <MTLDevice> device;
+    id<MTLDevice> device;
     NSString * s;
 
 #if TARGET_OS_OSX
@@ -245,6 +245,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
     }
 
+#if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+    if (ctx->device.maxTransferRate != 0) {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+    } else {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+    }
+#endif
+
     // load kernels
     {
         NSError * error = nil;
@@ -331,29 +354,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-#if TARGET_OS_OSX
-    // print MTL GPU family:
-    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    // determine max supported GPU family
-    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-        if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
-            break;
-        }
-    }
-
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
-    }
-#endif
-
     return ctx;
 }
 
@@ -471,6 +471,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
     return ctx->concur_list;
 }
 
+// temporarily defined here for compatibility between ggml-backend and the old API
+struct ggml_backend_metal_buffer_context {
+    void * data;
+
+    id<MTLBuffer> metal;
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -480,8 +487,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
-    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
-        ctx = t->buffer->backend->context;
+    // compatibility with ggml-backend
+    if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+        GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+        *offs = (size_t) ioffs;
+
+        return buf_ctx->metal;
     }
 
     // find the view that contains the tensor fully
@@ -1619,81 +1635,150 @@ void ggml_metal_graph_compute(
 
 // backend interface
 
-static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;
 
-    UNUSED(backend);
+static id<MTLDevice> ggml_backend_metal_get_device(void) {
+    if (g_backend_device == nil) {
+        g_backend_device = MTLCreateSystemDefaultDevice();
+    }
+
+    g_backend_device_ref_count++;
+
+    return g_backend_device;
 }
 
-static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
-    ggml_metal_free(ctx);
-    free(backend);
+static void ggml_backend_metal_free_device(void) {
+    assert(g_backend_device_ref_count > 0);
+
+    g_backend_device_ref_count--;
+
+    if (g_backend_device_ref_count == 0) {
+        [g_backend_device release];
+        g_backend_device = nil;
+    }
 }
 
 static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)buffer->context;
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    return ctx->data;
 }
 
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    [ctx->metal release];
+    ggml_backend_metal_free_device();
+
+    free(ctx->data);
+    free(ctx);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
     UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i metal_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_metal_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_metal_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_metal_buffer_cpy_tensor_to,
 };
 
-static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
 
-    void * data = ggml_metal_host_malloc(size);
+    const size_t size_page = sysconf(_SC_PAGESIZE);
 
-    // TODO: set proper name of the buffers
-    ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    ctx->data  = ggml_metal_host_malloc(size);
+    ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+                    length:size_aligned
+                    options:MTLResourceStorageModeShared
+                    deallocator:nil];
 
-    return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }
 
-static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(backend);
+    UNUSED(buft);
 }
 
-static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy((char *)tensor->data + offset, data, size);
+static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_metal;
 }
 
-static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+    return "Metal";
+
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ggml_metal_free(ctx);
+    free(backend);
+}
 
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_metal_buffer_type();
 
     UNUSED(backend);
 }
@@ -1705,32 +1790,79 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return true;
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_CONCAT:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_GET_ROWS:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_NORM:
+        case GGML_OP_ALIBI:
+        case GGML_OP_ROPE:
+        case GGML_OP_IM2COL:
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            return true;
+        default:
+            return false;
+    }
+
     UNUSED(backend);
-    UNUSED(op);
 }
 
 static struct ggml_backend_i metal_backend_i = {
-    /* .get_name            = */ ggml_backend_metal_name,
-    /* .free                = */ ggml_backend_metal_free,
-    /* .alloc_buffer        = */ ggml_backend_metal_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_metal_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_metal_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_metal_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_metal_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_metal_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_metal_cpy_tensor_to,
-    /* .graph_plan_create   = */ NULL, // the metal implementation does not require creating graph plans atm
-    /* .graph_plan_free     = */ NULL,
-    /* .graph_plan_compute  = */ NULL,
-    /* .graph_compute       = */ ggml_backend_metal_graph_compute,
-    /* .supports_op         = */ ggml_backend_metal_supports_op,
+    /* .get_name                = */ ggml_backend_metal_name,
+    /* .free                    = */ ggml_backend_metal_free,
+    /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ ggml_backend_metal_synchronize,
+    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
+    /* .supports_op             = */ ggml_backend_metal_supports_op,
 };
 
+// TODO: make a common log callback for all backends in ggml-backend
+static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+    ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
 
-    ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+    if (ctx == NULL) {
+        return NULL;
+    }
 
     ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
@@ -1747,7 +1879,18 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
 }
 
 void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
     struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
     ggml_metal_set_n_cb(ctx, n_cb);
 }
+
+static ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+    return ggml_backend_metal_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
+}
+
+GGML_BACKEND_REGISTER("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL)
index bc9e3c9b31cfcc1cfc227c2bbaea2fa762a6c2ac..1b192b765f80a764c3cbaa767edef603f30d2ecf 100644 (file)
@@ -1752,6 +1752,24 @@ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
+
+static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "TANH",
+    "ELU",
+    "RELU",
+    "GELU",
+    "GELU_QUICK",
+    "SILU",
+    "LEAKY",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -2023,6 +2041,20 @@ const char * ggml_op_symbol(enum ggml_op op) {
     return GGML_OP_SYMBOL[op];
 }
 
+const char * ggml_unary_op_name(enum ggml_unary_op op) {
+    return GGML_UNARY_OP_NAME[op];
+}
+
+const char * ggml_op_desc(const struct ggml_tensor * t) {
+    if (t->op == GGML_OP_UNARY) {
+        enum ggml_unary_op uop = ggml_get_unary_op(t);
+        return ggml_unary_op_name(uop);
+    }
+    else {
+        return ggml_op_name(t->op);
+    }
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return ggml_type_size(tensor->type);
 }
@@ -8697,7 +8729,6 @@ static void ggml_compute_forward_gelu_f32(
     // row range for this thread
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
-
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_gelu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
@@ -15204,12 +15235,8 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
     return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
 }
 
-struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
-    const size_t obj_size = sizeof(struct ggml_cgraph);
-    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
-    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
-
-    *cgraph = (struct ggml_cgraph) {
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
+    struct ggml_cgraph cgraph = {
         /*.size         =*/ 0,
         /*.n_nodes      =*/ i1 - i0,
         /*.n_leafs      =*/ 0,
@@ -15477,6 +15504,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     {
                         n_tasks = n_threads;
                     } break;
+                default:
+                    GGML_ASSERT(false);
             }
             break;
         case GGML_OP_SILU_BACK:
index e069e4e6deecf2c0022a99df6af6a4370cb780c9..63b7fac00183a6d39ef9702d9421f01cf86a18b4 100644 (file)
@@ -384,3 +384,23 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+
+#
+# test-backend-buffer
+
+set(TEST_TARGET test-backend-buffer)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+
+#
+# test-backend-ops
+
+set(TEST_TARGET test-backend-ops)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
diff --git a/tests/test-backend-buffer.cpp b/tests/test-backend-buffer.cpp
new file mode 100644 (file)
index 0000000..0110144
--- /dev/null
@@ -0,0 +1,84 @@
+#include <cstring>
+#include <ggml.h>
+#include <ggml-alloc.h>
+#include <ggml-backend.h>
+#include <ggml-backend-impl.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+
+static bool is_pow2(size_t x) {
+    return (x & (x - 1)) == 0;
+}
+
+static void test_buffer(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_backend_get_default_buffer_type(backend) == buft);
+
+    GGML_ASSERT(ggml_backend_buft_supports_backend(buft, backend));
+
+    //ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1024);
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, 1024);
+
+    GGML_ASSERT(buffer != NULL);
+
+    GGML_ASSERT(is_pow2(ggml_backend_buffer_get_alignment(buffer)));
+
+    GGML_ASSERT(ggml_backend_buffer_get_base(buffer) != NULL);
+
+    GGML_ASSERT(ggml_backend_buffer_get_size(buffer) >= 1024);
+
+    struct ggml_init_params params = {
+        /* .mem_size = */ 1024,
+        /* .mem_base = */ NULL,
+        /* .no_alloc = */ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+
+    static const size_t n = 10;
+
+    struct ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
+
+    GGML_ASSERT(ggml_backend_buffer_get_alloc_size(buffer, tensor) >= n * sizeof(float));
+
+    ggml_tallocr_t allocr = ggml_tallocr_new_from_buffer(buffer);
+    ggml_tallocr_alloc(allocr, tensor);
+
+    GGML_ASSERT(tensor->data != NULL);
+
+    GGML_ASSERT(tensor->data >= ggml_backend_buffer_get_base(buffer));
+
+    float data[n];
+    for (size_t i = 0; i < n; i++) {
+        data[i] = (float) i;
+    }
+
+    ggml_backend_tensor_set(tensor, data, 0, sizeof(data));
+
+    float data2[n];
+    ggml_backend_tensor_get(tensor, data2, 0, sizeof(data2));
+
+    GGML_ASSERT(memcmp(data, data2, sizeof(data)) == 0);
+
+    ggml_tallocr_free(allocr);
+    ggml_backend_buffer_free(buffer);
+    ggml_free(ctx);
+}
+
+int main() {
+    // enumerate backends
+    printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());
+
+    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
+        printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
+
+        ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
+        GGML_ASSERT(backend != NULL);
+        printf("  Backend name: %s\n", ggml_backend_name(backend));
+
+        test_buffer(backend, ggml_backend_reg_get_default_buffer_type(i));
+
+        ggml_backend_free(backend);
+
+        printf("  OK\n\n");
+    }
+}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
new file mode 100644 (file)
index 0000000..b42b0fa
--- /dev/null
@@ -0,0 +1,934 @@
+#include <ggml.h>
+#include <ggml-alloc.h>
+#include <ggml-backend.h>
+#include <ggml-backend-impl.h>
+#include <array>
+#include <cstring>
+#include <cfloat>
+#include <functional>
+#include <memory>
+#include <random>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+
+static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
+    size_t size = ggml_nelements(tensor);
+    std::vector<float> data(size);
+
+    std::random_device rd;
+    std::default_random_engine generator(rd());
+    std::uniform_real_distribution<float> distribution(min, max);
+
+    for (size_t i = 0; i < size; i++) {
+        data[i] = distribution(generator);
+    }
+
+    if (tensor->type == GGML_TYPE_F32) {
+        ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
+    } else if (tensor->type == GGML_TYPE_F16) {
+        std::vector<ggml_fp16_t> data16(size);
+        ggml_fp32_to_fp16_row(data.data(), data16.data(), size);
+        ggml_backend_tensor_set(tensor, data16.data(), 0, size * sizeof(ggml_fp16_t));
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
+static std::vector<float> tensor_to_float(const ggml_tensor * t) {
+    std::vector<float> tv;
+    tv.reserve(ggml_nelements(t));
+
+    std::vector<uint8_t> buf(ggml_nbytes(t));
+    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+
+    // access elements by index to avoid gaps in views
+    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
+                    float v;
+                    if (t->type == GGML_TYPE_F16) {
+                        v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+                    } else if (t->type == GGML_TYPE_F32) {
+                        v = *(float *) &buf[i];
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                    tv.push_back(v);
+                }
+            }
+        }
+    }
+
+    return tv;
+}
+
+/*
+static double cosine_similarity(const float * v1, const float * v2, size_t n) {
+    double dot = 0.0;
+    double mag1 = 0.0;
+    double mag2 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
+            return -1.0f;
+        }
+        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
+            continue;
+        }
+        dot  += v1[i]*v2[i];
+        mag1 += v1[i]*v1[i];
+        mag2 += v2[i]*v2[i];
+    }
+
+    return dot/sqrt(mag1*mag2);
+}
+
+static float distance(const float * v1, const float * v2, size_t n) {
+    double d = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
+            return INFINITY;
+        }
+        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
+            continue;
+        }
+        d += (v1[i] - v2[i])*(v1[i] - v2[i]);
+    }
+
+    return sqrt(d);
+}
+
+static float vec_len(const float * v, size_t n) {
+    double d = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v[i])) {
+            return INFINITY;
+        }
+        if (std::isinf(v[i])) {
+            continue;
+        }
+        d += v[i]*v[i];
+    }
+
+    return sqrt(d);
+}
+*/
+
+// normalized mean squared error = mse(a, b) / mse(a, 0)
+static double nmse(const float * a, const float * b, size_t n) {
+    double mse_a_b = 0.0;
+    double mse_a_0 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        float a_i = a[i];
+        float b_i = b[i];
+
+        mse_a_b += (a_i - b_i) * (a_i - b_i);
+        mse_a_0 += a_i * a_i;
+    }
+
+    return mse_a_b / mse_a_0;
+}
+
+// utils for printing the variables of the test cases
+#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
+
+template<typename T>
+static std::string var_to_str(const T & x) {
+    return std::to_string(x);
+}
+
+template<typename T, size_t N>
+static std::string var_to_str(const T (&x)[N]) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+template<typename T, size_t N>
+static std::string var_to_str(const std::array<T, N> & x) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+//static std::string var_to_str(ggml_unary_op unary_op) {
+//    return ggml_unary_op_name(unary_op);
+//}
+
+static std::string var_to_str(ggml_type type) {
+    return ggml_type_name(type);
+}
+
+#define VARS_TO_STR1(a) VAR_TO_STR(a)
+#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
+#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
+#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + "," + VARS_TO_STR3(b, c, d)
+#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + "," + VARS_TO_STR4(b, c, d, e)
+#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + "," + VARS_TO_STR5(b, c, d, e, f)
+#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + "," + VARS_TO_STR6(b, c, d, e, f, g)
+#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + "," + VARS_TO_STR7(b, c, d, e, f, g, h)
+#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
+#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
+#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
+
+
+// accept FLT_MAX as infinity
+static bool isinf_or_max(float f) {
+    return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX;
+}
+
+struct test_case {
+    virtual ~test_case() {}
+
+    virtual std::string vars() {
+        return "";
+    }
+
+    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
+
+    virtual void initialize_tensors(ggml_context * ctx) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t);
+        }
+    }
+
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2) {
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+       };
+        ggml_context * ctx = ggml_init(params);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        // check if backends support op
+        for (ggml_backend_t backend : {backend1, backend2}) {
+            if (!ggml_backend_supports_op(backend, out)) {
+                printf("  %s: not supported\n", ggml_op_desc(out));
+                ggml_free(ctx);
+                return true;
+            }
+        }
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
+        // build graph
+        ggml_cgraph * gf = ggml_new_graph(ctx);
+        ggml_build_forward_expand(gf, out);
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        // compare
+        bool ok = true;
+
+        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            std::vector<float> f1 = tensor_to_float(t1);
+            std::vector<float> f2 = tensor_to_float(t2);
+            bool * ok = (bool *) user_data;
+
+            for (size_t i = 0; i < f1.size(); i++) {
+                // check for nans
+                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
+                    printf("    Error: %s: NaN\n", ggml_op_desc(t1));
+                    *ok = false;
+                    return true;
+                }
+                // check for infs: both must be inf of the same sign, or both must be finite
+                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
+                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
+                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
+                            printf("    Error: %s: inf sign mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
+                            *ok = false;
+                            return true;
+                        }
+                    } else {
+                        printf("    Error: %s: inf mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
+                        *ok = false;
+                        return true;
+                    }
+                }
+            }
+
+            double err = nmse(f1.data(), f2.data(), f1.size());
+            if (err > 1e-6) {
+                printf("    Error: %s: NMSE = %f\n", ggml_op_desc(t1), err);
+                *ok = false;
+            }
+            return true;
+       };
+
+        ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ok);
+
+        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        return ok;
+    }
+};
+
+// GGML_OP_UNARY
+struct test_unary : public test_case {
+    const ggml_unary_op op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_unary(ggml_unary_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {128, 10, 10, 10})
+        : op(op), type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_unary(ctx, in, op);
+        return out;
+    }
+};
+
+// GGML_OP_GET_ROWS
+struct test_get_rows : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to get
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, n, m, r);
+    }
+
+    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
+        : type(type), n(n), m(m), r(r) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
+        ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+        ggml_tensor * out = ggml_get_rows(ctx, in, rows);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // rows
+                std::vector<int> data(r);
+                for (int i = 0; i < r; i++) {
+                    data[i] = rand() % m;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_REPEAT
+struct test_repeat : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    test_repeat(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int, 4> nr = {2, 2, 2, 2})
+        : type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_repeat(ctx, src, target);
+        return out;
+    }
+};
+
+// GGML_OP_DUP
+struct test_dup : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_dup(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_dup(ctx, src);
+        return out;
+    }
+};
+
+// GGML_OP_CPY
+struct test_cpy : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type_src, type_dst, ne);
+    }
+
+    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type_src(type_src), type_dst(type_dst), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne.data());
+        ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        return out;
+    }
+};
+
+// GGML_OP_CONT
+struct test_cont : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cont(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        src = ggml_transpose(ctx, src);
+        ggml_tensor * out = ggml_cont(ctx, src);
+
+        return out;
+    }
+};
+
+// GGML_OP_ADD
+struct test_add : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int,4> nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    test_add(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 1, 1},
+            std::array<int, 4> nr = {1, 2, 1, 1})
+        : type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_add(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_MUL
+struct test_mul : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int,4> nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    test_mul(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 1, 1},
+            std::array<int, 4> nr = {1, 2, 1, 1})
+        : type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_mul(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_SCALE
+struct test_scale : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_scale(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * scale = ggml_new_tensor_1d(ctx, type, 1);
+        ggml_tensor * out = ggml_scale(ctx, a, scale);
+        return out;
+    }
+};
+
+// GGML_OP_NORM
+struct test_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_norm(ctx, a, eps);
+        return out;
+    }
+};
+
+// GGML_OP_RMS_NORM
+struct test_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        return out;
+    }
+};
+
+// GGML_OP_MUL_MAT
+struct test_mul_mat : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+
+    std::string vars() override {
+        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
+    }
+
+    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            std::array<int64_t, 2> nr = {2, 2})
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_SQR
+struct test_sqr : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqr(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqr(ctx, a);
+        return out;
+    }
+};
+
+// GGML_OP_CLAMP
+struct test_clamp : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float min;
+    float max;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, min, max);
+    }
+
+    test_clamp(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float min = -0.5f, float max = 0.5f)
+        : type(type), ne(ne), min(min), max(max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_clamp(ctx, a, min, max);
+        return out;
+    }
+};
+
+// GGML_OP_DIAG_MASK_INF
+struct test_diag_mask_inf : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int n_past;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, n_past);
+    }
+
+    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int n_past = 5)
+        : type(type), ne(ne), n_past(n_past) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
+        return out;
+    }
+};
+
+// GGML_OP_SOFT_MAX
+struct test_soft_max : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_soft_max(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_soft_max(ctx, a);
+        return out;
+    }
+};
+
+// GGML_OP_ROPE
+struct test_rope : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    int n_dims;
+    int mode;
+    int n_ctx;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+    }
+
+    test_rope(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1},
+            int n_dims = 10, int mode = 0, int n_ctx = 512)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector<int> data(ne[2]);
+                for (int i = 0; i < ne[2]; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_ALIBI
+struct test_alibi : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    int n_past;
+    int n_head;
+    float bias_max;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, n_past, n_head, bias_max);
+    }
+
+    test_alibi(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int n_past = 512, int n_head = 10, float bias_max = 0.5f)
+        : type(type), ne(ne), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
+        return out;
+    }
+};
+
+// GGML_OP_IM2COL
+struct test_im2col : public test_case {
+    const ggml_type type_input;
+    const ggml_type type_kernel;
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    // stride
+    const int s0;
+    const int s1;
+    // padding
+    const int p0;
+    const int p1;
+    // dilatation
+    const int d0;
+    const int d1;
+    // mode
+    const bool is_2D;
+
+    std::string vars() override {
+        return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
+    }
+
+    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
+            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+            int s0 = 1, int s1 = 1,
+            int p0 = 1, int p1 = 1,
+            int d0 = 1, int d1 = 1,
+            bool is_2D = true)
+        : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
+        return out;
+    }
+};
+
+// GGML_OP_CONCAT
+struct test_concat : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int64_t b_ne2;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, b_ne2);
+    }
+
+    test_concat(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int64_t b_ne2 = 10)
+        : type(type), ne(ne), b_ne2(b_ne2) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
+        ggml_tensor * out = ggml_concat(ctx, a, b);
+        return out;
+    }
+};
+
+static bool test_backend(ggml_backend_t backend) {
+    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+
+    std::vector<std::unique_ptr<test_case>> test_cases;
+
+    // unary ops
+    for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+        test_cases.emplace_back(new test_unary((ggml_unary_op) op));
+    }
+
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
+        test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
+    }
+
+    test_cases.emplace_back(new test_repeat());
+    test_cases.emplace_back(new test_dup());
+    test_cases.emplace_back(new test_cpy());
+    test_cases.emplace_back(new test_cont());
+
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}));
+    //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}));
+    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}));
+    //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}));
+
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}));
+    //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}));
+    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}));
+    //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}));
+
+    test_cases.emplace_back(new test_scale());
+
+    for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
+        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
+        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
+    }
+
+    for (ggml_type t0 : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (ggml_type t1 : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            // FIXME: CPU crashes on f16xf16
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, { 1,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10,  1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 2}));
+        }
+    }
+
+    test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_clamp());
+
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
+
+    test_cases.emplace_back(new test_soft_max());
+
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512)); // llama 65B
+        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512)); // neox (falcon 7B)
+        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
+        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
+        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
+        //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox rope (stablelm) (TODO: enable after llama.cpp sync)
+    }
+
+    test_cases.emplace_back(new test_alibi());
+    test_cases.emplace_back(new test_im2col());
+    test_cases.emplace_back(new test_concat());
+
+    size_t n_ok = 0;
+    for (auto & test : test_cases) {
+        if (test->eval(backend, backend_cpu)) {
+            n_ok++;
+        }
+    }
+
+    printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+    ggml_backend_free(backend_cpu);
+
+    return n_ok == test_cases.size();
+}
+
+int main() {
+    // enumerate backends
+    printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());
+
+    size_t n_ok = 0;
+
+    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
+        printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
+
+        ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
+        GGML_ASSERT(backend != NULL);
+        printf("  Backend name: %s\n", ggml_backend_name(backend));
+
+        bool ok = test_backend(backend);
+
+        printf("  Backend %s: ", ggml_backend_name(backend));
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_free(backend);
+    }
+
+    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+    if (n_ok != ggml_backend_reg_get_count()) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    } else {
+        printf("\033[1;32mOK\033[0m\n");
+        return 0;
+    }
+}
index 98d45c5d0bed3cfd3a12b82cf7e064711b627747..0e5c75f7e7598447d9c9c0c45901c4487dd98320 100644 (file)
@@ -71,7 +71,7 @@ void load_model(test_model & model, bool use_gpu = false) {
 #ifdef GGML_USE_CUBLAS
     if (use_gpu) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
+        model.backend = ggml_backend_cuda_init(0);
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
index 66e6e32f47a24ccf0a4c8907063932c37524d3ed..64eee784809bfa9f8007b691126e073592f096ed 100644 (file)
@@ -71,7 +71,7 @@ void load_model(test_model & model, bool use_gpu = false) {
 #ifdef GGML_USE_CUBLAS
     if (use_gpu) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
+        model.backend = ggml_backend_cuda_init(0);
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
index d1406d4e755daeeced8939d1095b09c4e8f67eee..1811492c2231f2c64d0e9279e12bd7556ec7783d 100644 (file)
@@ -51,7 +51,7 @@ void load_model(test_model & model, float* a, float* b, int M, int N, int K, boo
 #ifdef GGML_USE_CUBLAS
     if (use_gpu) {
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init();
+        model.backend = ggml_backend_cuda_init(0);
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }