move BLAS to a separate backend (cont) (llama/6210)
author    slaren <redacted>
Sun, 16 Jun 2024 10:57:37 +0000 (13:57 +0300)
committer Georgi Gerganov <redacted>
Sun, 16 Jun 2024 15:33:49 +0000 (18:33 +0300)
ggml-ci

.github/workflows/ci.yml
.gitignore
CMakeLists.txt
examples/common.h
examples/gpt-2/main-sched.cpp
ggml-blas.cpp [deleted file]
ggml-blas.h [deleted file]
src/CMakeLists.txt
src/ggml-blas.cpp [new file with mode: 0644]
src/ggml-blas.h [new file with mode: 0644]

index 4da1cd48a79c8344b89ed3b687a55784549bb206..817e669530be8bc0caf21dd30b88be4b2a2ad1d9 100644 (file)
@@ -61,7 +61,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF ..
 
     - name: Build
       working-directory: ./build
@@ -112,7 +112,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF ..
 
     - name: Build
       working-directory: ./build
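
Note: both coverage jobs now pass -DGGML_METAL=OFF explicitly. This is presumably needed because, with this commit, GGML_METAL defaults to ON on Apple hosts (see the CMakeLists.txt change below), and these runs are meant to exercise the CPU/BLAS code paths.
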
index d588aa9984fb4ccd0b443199a716efad4a917896..dd2ca4b975e33c29affd5d64e0ceb9e3c5d740e3 100644 (file)
@@ -1,4 +1,5 @@
 build/
+build-blas/
 build-debug/
 build-release/
 build-sanitize-addr/
index 2af2122d419e0579b46b9e2f8773719273563dc9..f8f418bfa388c0f9e4aa404d74468f75dbfda051 100644 (file)
@@ -25,6 +25,16 @@ endif()
 
 # options
 
+if (APPLE)
+    set(GGML_METAL_DEFAULT ON)
+    set(GGML_BLAS_DEFAULT ON)
+    set(GGML_BLAS_VENDOR_DEFAULT "Apple")
+else()
+    set(GGML_METAL_DEFAULT OFF)
+    set(GGML_BLAS_DEFAULT OFF)
+    set(GGML_BLAS_VENDOR_DEFAULT "Generic")
+endif()
+
 option(BUILD_SHARED_LIBS            "ggml: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
 
 option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings"                   ON)
@@ -41,11 +51,13 @@ option(GGML_TEST_COVERAGE           "ggml: enable test coverage" OFF)
 
 option(GGML_PERF                    "ggml: enable perf timings"               OFF)
 option(GGML_NO_ACCELERATE           "ggml: disable Accelerate framework"      OFF)
-option(GGML_OPENBLAS                "ggml: use OpenBLAS"                      OFF)
+option(GGML_BLAS                    "ggml: use BLAS"                          ${GGML_BLAS_DEFAULT})
+set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
+                                    "ggml: BLAS library vendor")
 option(GGML_HIPBLAS                 "ggml: use hipBLAS"                       OFF)
 option(GGML_CUDA                    "ggml: use CUDA"                          OFF)
 option(GGML_CUBLAS                  "ggml: use CUDA (deprecated)"             OFF)
-option(GGML_METAL                   "ggml: use Metal"                         OFF)
+option(GGML_METAL                   "ggml: use Metal"                         ${GGML_METAL_DEFAULT})
 option(GGML_METAL_NDEBUG            "ggml: disable Metal debugging"           OFF)
 option(GGML_METAL_SHADER_DEBUG      "ggml: compile Metal with -fno-fast-math" OFF)
 option(GGML_METAL_EMBED_LIBRARY     "ggml: embed Metal library"               OFF)
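
Note: this hunk replaces the old GGML_OPENBLAS switch with a generic GGML_BLAS option plus a GGML_BLAS_VENDOR cache variable, and changes the Apple defaults so that Metal and Apple's BLAS (Accelerate) are enabled out of the box. Since GGML_BLAS_VENDOR is forwarded to FindBLAS as BLA_VENDOR (see src/CMakeLists.txt below), a configure invocation is expected to look roughly like the following, with the vendor name given only as an example:

    cmake -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS ..
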
index 2ed91ca9aa80dc2b3a152a278c71d197314201ed..79b98309553fc88d0413497d6c54b44403d47601 100644 (file)
@@ -21,7 +21,7 @@ struct gpt_params {
     int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict    = 200;  // new tokens to predict
     int32_t n_parallel   = 1;    // number of parallel streams
-    int32_t n_batch      = 8;    // batch size for prompt processing
+    int32_t n_batch      = 32;   // batch size for prompt processing
     int32_t n_ctx        = 2048; // context size (this is the KV cache max size)
     int32_t n_gpu_layers = 0;    // number of layers to offload to the GPU
 
index bdf3bff8233cbfc061279c0469868ef816462b74..11c72973d2273c52e7d3a8515e35cb764fa1687d 100644 (file)
 #include "ggml-metal.h"
 #endif
 
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
 #include "common.h"
 #include "common-ggml.h"
 
@@ -131,6 +135,16 @@ void init_backends(gpt2_model & model, const gpt_params & params) {
         model.backends.push_back(gpu_backend);
     }
 
+#ifdef GGML_USE_BLAS
+    ggml_backend_t blas_backend = ggml_backend_blas_init();
+    if (!blas_backend) {
+        fprintf(stderr, "%s: failed to initialize BLAS backend\n", __func__);
+    } else {
+        ggml_backend_blas_set_n_threads(blas_backend, params.n_threads);
+        model.backends.push_back(blas_backend);
+    }
+#endif
+
     // always add the CPU backend as a fallback
     ggml_backend_t cpu_backend = ggml_backend_cpu_init();
     ggml_backend_cpu_set_n_threads(cpu_backend, params.n_threads);
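
For symmetry with the initialization above, a minimal teardown sketch follows (not part of the commit; it assumes model.backends is a std::vector<ggml_backend_t>, as the push_back calls suggest). ggml_backend_free() dispatches to each backend's free function, including ggml_backend_blas_free() for the BLAS backend:

    // hypothetical cleanup mirroring init_backends()
    void free_backends(gpt2_model & model) {
        for (ggml_backend_t backend : model.backends) {
            ggml_backend_free(backend); // releases the BLAS/GPU/CPU backend objects
        }
        model.backends.clear();
    }
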
diff --git a/ggml-blas.cpp b/ggml-blas.cpp
deleted file mode 100644 (file)
index d709a35..0000000
+++ /dev/null
@@ -1,363 +0,0 @@
-#include "ggml-blas.h"
-#include "ggml-backend-impl.h"
-
-#include <future>
-#include <vector>
-
-#if defined(GGML_USE_ACCELERATE)
-#   include <Accelerate/Accelerate.h>
-#elif defined(GGML_BLAS_USE_MKL)
-#   include <mkl.h>
-#else
-#   include <cblas.h>
-#   ifdef BLIS_ENABLE_CBLAS
-#       include <blis.h>
-#   endif
-#endif
-
-struct ggml_backend_blas_context {
-    int n_threads = GGML_DEFAULT_N_THREADS;
-    std::unique_ptr<char[]> work_data;
-    size_t work_size = 0;
-#ifndef GGML_USE_OPENMP
-    std::vector<std::future<void>> tasks;
-#endif
-};
-
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
-static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const int64_t ne_plane      = ne01*ne00;
-    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
-
-    if (ctx->work_size < desired_wsize) {
-        ctx->work_data.reset(new char[desired_wsize]);
-        ctx->work_size = desired_wsize;
-    }
-    void * wdata = ctx->work_data.get();
-
-    // convert src0 to float
-    if (type != GGML_TYPE_F32) {
-        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
-        ggml_to_float_t const to_float = type_traits.to_float;
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;
-                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;
-
-                const int min_cols_per_thread = 4096;
-                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
-                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
-
-#ifdef GGML_USE_OPENMP
-                #pragma omp parallel for num_threads(n_threads)
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                }
-#else
-                for (int i = 1; i < n_threads; i++) {
-                    const int64_t start =       i*ne01/n_threads;
-                    const int64_t end   = (i + 1)*ne01/n_threads;
-                    if (start < end) {
-                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
-                            for (int64_t i01 = start; i01 < end; i01++) {
-                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                            }
-                        }));
-                    }
-                }
-                {
-                    // reuse the current thread for the first task
-                    const int64_t start = 0;
-                    const int64_t end   = ne01/n_threads;
-                    for (int64_t i01 = start; i01 < end; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
-                }
-#endif
-            }
-        }
-
-#ifndef GGML_USE_OPENMP
-        // wait for all tasks to finish
-        for (auto & task : ctx->tasks) {
-            task.get();
-        }
-        ctx->tasks.clear();
-#endif
-    }
-
-#if defined(OPENBLAS_VERSION)
-    openblas_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(BLIS_ENABLE_CBLAS)
-    bli_thread_set_num_threads(ctx->n_threads);
-#endif
-
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            const int64_t i03 = i13/r3;
-            const int64_t i02 = i12/r2;
-
-            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
-
-            if (type != GGML_TYPE_F32) {
-                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
-            }
-
-            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne1, ne01, ne10,
-                        1.0f,   y, ne10,
-                                x, ne00,
-                        0.0f,   d, ne01);
-        }
-    }
-}
-
-static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
-    // src0: (k,n)
-    // src1: (k,m)
-    // dst:  (m,n)
-    //
-    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
-    // Also expressed as (major,minor)
-    // a: (m,k): so src1 transposed
-    // b: (k,n): so src0
-    // c: (m,n)
-    //
-    // However, if ggml_is_transposed(src1) is true, then
-    // src1->data already contains a transposed version, so sgemm mustn't
-    // transpose it further.
-
-    int n = src0->ne[0];
-    int k = src0->ne[1];
-    int m = src1->ne[0];
-
-    CBLAS_TRANSPOSE transposeA;
-    int lda;
-
-    if (!ggml_is_transposed(src1)) {
-        transposeA = CblasTrans;
-        lda = m;
-    } else {
-        transposeA = CblasNoTrans;
-        lda = k;
-    }
-
-    float * a = (float *) ((char *) src1->data);
-    float * b = (float *) ((char *) src0->data);
-    float * c = (float *) ((char *) dst->data);
-
-    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
-
-    GGML_UNUSED(ctx);
-}
-
-// backend interface
-
-GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
-    return "BLAS";
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
-    delete ctx;
-    delete backend;
-}
-
-GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        switch (node->op) {
-            case GGML_OP_MUL_MAT:
-                ggml_backend_blas_mul_mat(ctx, node);
-                break;
-
-            case GGML_OP_OUT_PROD:
-                ggml_backend_blas_out_prod(ctx, node);
-                break;
-
-            case GGML_OP_NONE:
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_PERMUTE:
-            case GGML_OP_TRANSPOSE:
-                break;
-
-            default:
-                fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
-                GGML_ASSERT(false);
-        }
-    }
-
-    return GGML_STATUS_SUCCESS;
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-                                          op->src[1]->type == GGML_TYPE_F32 &&
-                                          ggml_is_matrix(src0) &&
-                                          ggml_is_matrix(src1) &&
-                                          ggml_is_contiguous(src0) &&
-                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
-    /* .free                    = */ ggml_backend_blas_free,
-    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
-    /* .offload_op              = */ NULL,
-    /* .event_new               = */ NULL,
-    /* .event_free              = */ NULL,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-    /* .event_synchronize       = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_blas_guid(void) {
-    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_blas_init(void) {
-    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
-
-    ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_blas_guid(),
-        /* .interface = */ blas_backend_i,
-        /* .context   = */ ctx,
-    };
-
-#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
-    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
-        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
-    }
-#endif
-
-#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
-    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
-#endif
-
-    return backend;
-}
-
-GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
-}
-
-void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
-    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
-
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
-    ctx->n_threads = n_threads;
-}
diff --git a/ggml-blas.h b/ggml-blas.h
deleted file mode 100644 (file)
index f2e37de..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-// backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
-
-GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
-
-// number of threads used for conversion to float
-// for openblas and blis, this will also set the number of threads used for blas operations
-GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
-
-
-#ifdef  __cplusplus
-}
-#endif
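
Note: the two deletions above and the two additions under src/ below carry the same blob hashes (d709a35 for ggml-blas.cpp, f2e37de for ggml-blas.h), so the backend sources are moved verbatim from the repository root into src/; only the build system references change.
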
index 25fb1ad170884ecd5335c3736654b816c87c5b45..0330c9b36e026a2f5ecbb925bae8a9206b905621 100644 (file)
@@ -152,28 +152,89 @@ if (APPLE AND NOT GGML_NO_ACCELERATE)
     endif()
 endif()
 
-if (GGML_OPENBLAS)
-    set(OPENBLAS_INCLUDE_SEARCH_PATHS
-        /usr/include
-        /usr/include/openblas
-        /usr/include/openblas-base
-        /usr/local/include
-        /usr/local/include/openblas
-        /usr/local/include/openblas-base
-        /opt/OpenBLAS/include
-        $ENV{OpenBLAS_HOME}
-        $ENV{OpenBLAS_HOME}/include
-        )
-    find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
-    find_library(OPENBLAS_LIB NAMES openblas libopenblas)
-    if (OPENBLAS_LIB)
-        message(STATUS "OpenBLAS found")
-
-        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${OPENBLAS_LIB})
-        set(GGML_EXTRA_INCS  ${GGML_EXTRA_INCS}  ${OPENBLAS_INC})
-        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
+if (GGML_BLAS)
+    if (GGML_STATIC)
+        set(BLA_STATIC ON)
+    endif()
+    #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
+    #    set(BLA_SIZEOF_INTEGER 8)
+    #endif()
+
+    set(BLA_VENDOR ${GGML_BLAS_VENDOR})
+    find_package(BLAS)
+
+    if (BLAS_FOUND)
+        message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
+
+        if (("${BLAS_INCLUDE_DIRS}" STREQUAL "") AND NOT (${GGML_BLAS_VENDOR} MATCHES "Apple"))
+            # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
+            # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
+            find_package(PkgConfig REQUIRED)
+            if (${GGML_BLAS_VENDOR} MATCHES "Generic")
+                pkg_check_modules(DepBLAS REQUIRED blas)
+            elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
+                # As of openblas v0.3.22, the 64-bit is named openblas64.pc
+                pkg_check_modules(DepBLAS openblas64)
+                if (NOT DepBLAS_FOUND)
+                    pkg_check_modules(DepBLAS REQUIRED openblas)
+                endif()
+            elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
+                pkg_check_modules(DepBLAS REQUIRED blis)
+            elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
+                pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+            elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
+                pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+            elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+                # all Intel* libraries share the same include path
+                pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+            elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
+                # this doesn't provide pkg-config
+                # suggest to assign BLAS_INCLUDE_DIRS on your own
+                if ("${NVHPC_VERSION}" STREQUAL "")
+                    message(WARNING "Better to set NVHPC_VERSION")
+                else()
+                    set(DepBLAS_FOUND ON)
+                    set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
+                endif()
+            endif()
+            if (DepBLAS_FOUND)
+                set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
+            else()
+                message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
+                " detected by pkgconfig, trying to find cblas.h from possible paths...")
+                find_path(BLAS_INCLUDE_DIRS
+                    NAMES cblas.h
+                    HINTS
+                        /usr/include
+                        /usr/local/include
+                        /usr/include/openblas
+                        /opt/homebrew/opt/openblas/include
+                        /usr/local/opt/openblas/include
+                        /usr/include/x86_64-linux-gnu/openblas/include
+                )
+            endif()
+        endif()
+
+        message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
+
+        add_compile_options(${BLAS_LINKER_FLAGS})
+
+        add_compile_definitions(GGML_USE_BLAS)
+
+        if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+            add_compile_definitions(GGML_BLAS_USE_MKL)
+        endif()
+
+        set(GGML_HEADERS_BLAS ggml-blas.h)
+        set(GGML_SOURCES_BLAS ggml-blas.cpp)
+
+        set(GGML_EXTRA_LIBS     ${GGML_EXTRA_LIBS}     ${BLAS_LIBRARIES})
+        set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
+        set(GGML_EXTRA_FLAGS    ${GGML_EXTRA_FLAGS} -DGGML_USE_BLAS)
     else()
-        message(WARNING "OpenBLAS not found")
+        message(WARNING "BLAS not found, please refer to "
+        "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+        " to set correct GGML_BLAS_VENDOR")
     endif()
 endif()
 
@@ -513,9 +574,10 @@ add_library(${TARGET}
     ../include/ggml/ggml.h
     ../include/ggml/ggml-alloc.h
     ../include/ggml/ggml-backend.h
-    ${GGML_SOURCES_CUDA}
-    ${GGML_SOURCES_METAL}
-    ${GGML_SOURCES_RPC}
+    ${GGML_SOURCES_CUDA}  ${GGML_HEADERS_CUDA}
+    ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
+    ${GGML_SOURCES_RPC}   ${GGML_HEADERS_RPC}
+    ${GGML_SOURCES_BLAS}  ${GGML_HEADERS_BLAS}
     )
 
 target_include_directories(${TARGET} PUBLIC
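
Note: GGML_BLAS_VENDOR is handed to FindBLAS as BLA_VENDOR, so the accepted values are the FindBLAS vendor names. The block above adds pkg-config fallbacks for Generic, OpenBLAS, FLAME (BLIS), ATLAS, FlexiBLAS and Intel, special-cases NVHPC (which ships no pkg-config file), and finally falls back to searching for cblas.h in common locations. When the resolved include path contains "mkl" and the vendor is Generic or Intel, GGML_BLAS_USE_MKL is defined so that ggml-blas.cpp includes mkl.h instead of cblas.h.
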
diff --git a/src/ggml-blas.cpp b/src/ggml-blas.cpp
new file mode 100644 (file)
index 0000000..d709a35
--- /dev/null
@@ -0,0 +1,363 @@
+#include "ggml-blas.h"
+#include "ggml-backend-impl.h"
+
+#include <future>
+#include <vector>
+
+#if defined(GGML_USE_ACCELERATE)
+#   include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#else
+#   include <cblas.h>
+#   ifdef BLIS_ENABLE_CBLAS
+#       include <blis.h>
+#   endif
+#endif
+
+struct ggml_backend_blas_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char[]> work_data;
+    size_t work_size = 0;
+#ifndef GGML_USE_OPENMP
+    std::vector<std::future<void>> tasks;
+#endif
+};
+
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+        src1->type == GGML_TYPE_F32 &&
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
+        return true;
+    }
+
+    return false;
+}
+
+static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const int64_t ne_plane      = ne01*ne00;
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+
+    if (ctx->work_size < desired_wsize) {
+        ctx->work_data.reset(new char[desired_wsize]);
+        ctx->work_size = desired_wsize;
+    }
+    void * wdata = ctx->work_data.get();
+
+    // convert src0 to float
+    if (type != GGML_TYPE_F32) {
+        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits.to_float;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;
+                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;
+
+                const int min_cols_per_thread = 4096;
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
+
+#ifdef GGML_USE_OPENMP
+                #pragma omp parallel for num_threads(n_threads)
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                }
+#else
+                for (int i = 1; i < n_threads; i++) {
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
+                    if (start < end) {
+                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                            }
+                        }));
+                    }
+                }
+                {
+                    // reuse the current thread for the first task
+                    const int64_t start = 0;
+                    const int64_t end   = ne01/n_threads;
+                    for (int64_t i01 = start; i01 < end; i01++) {
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                    }
+                }
+#endif
+            }
+        }
+
+#ifndef GGML_USE_OPENMP
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
+            task.get();
+        }
+        ctx->tasks.clear();
+#endif
+    }
+
+#if defined(OPENBLAS_VERSION)
+    openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(BLIS_ENABLE_CBLAS)
+    bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const int64_t i03 = i13/r3;
+            const int64_t i02 = i12/r2;
+
+            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
+
+            if (type != GGML_TYPE_F32) {
+                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            }
+
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne1, ne01, ne10,
+                        1.0f,   y, ne10,
+                                x, ne00,
+                        0.0f,   d, ne01);
+        }
+    }
+}
+
+static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    CBLAS_TRANSPOSE transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
+// backend interface
+
+GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+    return "BLAS";
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_backend_blas_mul_mat(ctx, node);
+                break;
+
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
+
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+                GGML_ASSERT(false);
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+                                          op->src[1]->type == GGML_TYPE_F32 &&
+                                          ggml_is_matrix(src0) &&
+                                          ggml_is_matrix(src1) &&
+                                          ggml_is_contiguous(src0) &&
+                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i blas_backend_i = {
+    /* .get_name                = */ ggml_backend_blas_name,
+    /* .free                    = */ ggml_backend_blas_free,
+    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
+    /* .supports_op             = */ ggml_backend_blas_supports_op,
+    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .offload_op              = */ NULL,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_blas_guid(void) {
+    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_blas_init(void) {
+    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_blas_guid(),
+        /* .interface = */ blas_backend_i,
+        /* .context   = */ ctx,
+    };
+
+#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
+        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+    }
+#endif
+
+#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#endif
+
+    return backend;
+}
+
+GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
+}
+
+void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
+
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
+    ctx->n_threads = n_threads;
+}
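
A note on the cblas_sgemm call in ggml_backend_blas_mul_mat above: with CblasRowMajor, CblasNoTrans, CblasTrans and sizes (ne1, ne01, ne10), each 2D plane computes

    d[ne1 x ne01] = y[ne1 x ne10] * x[ne01 x ne00]^T        (ne10 == ne00)

which is ggml's MUL_MAT convention dst = src1 * src0^T. Non-F32 src0 is first expanded to F32 into work_data (in parallel, via OpenMP or std::async), which is why desired_wsize is zero when src0 is already F32.
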
diff --git a/src/ggml-blas.h b/src/ggml-blas.h
new file mode 100644 (file)
index 0000000..f2e37de
--- /dev/null
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
+
+GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
+
+// number of threads used for conversion to float
+// for openblas and blis, this will also set the number of threads used for blas operations
+GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+
+
+#ifdef  __cplusplus
+}
+#endif
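
To show the new public API end to end, here is a minimal, self-contained sketch (not part of the commit). It assumes only the core ggml, ggml-alloc and ggml-backend headers in this tree plus ggml-blas.h above, and uses 64x64 F32 matrices so that the >= 32 size heuristic in ggml_backend_blas_use_blas() is met:

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-blas.h"

    #include <cstdio>
    #include <vector>

    int main() {
        ggml_backend_t backend = ggml_backend_blas_init();
        if (!backend) {
            fprintf(stderr, "failed to initialize the BLAS backend\n");
            return 1;
        }
        ggml_backend_blas_set_n_threads(backend, 4);

        // the context only holds tensor/graph metadata; data lives in a backend buffer
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // large enough to be handled by BLAS

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        // allocate the tensors in the backend's (host) buffer and upload input data
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
        std::vector<float> ones(64*64, 1.0f);
        ggml_backend_tensor_set(a, ones.data(), 0, ggml_nbytes(a));
        ggml_backend_tensor_set(b, ones.data(), 0, ggml_nbytes(b));

        if (ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS) {
            float c00 = 0.0f;
            ggml_backend_tensor_get(c, &c00, 0, sizeof(c00));
            printf("c[0,0] = %.1f (expected 64.0)\n", c00);
        }

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }

When run under ggml_backend_sched with multiple backends instead (as in main-sched.cpp above), the BLAS backend only claims MUL_MAT nodes that pass ggml_backend_blas_use_blas() and F32 OUT_PROD nodes, per ggml_backend_blas_supports_op(); everything else falls back to the CPU backend.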