From: slaren Date: Sun, 16 Jun 2024 10:57:37 +0000 (+0300) Subject: move BLAS to a separate backend (cont) (llama/6210) X-Git-Tag: upstream/0.0.1642~577 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=169738dc6658c28dccad876e9db5ff2180940186;p=pkg%2Fggml%2Fsources%2Fggml move BLAS to a separate backend (cont) (llama/6210) ggml-ci --- diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4da1cd48..817e6695 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,7 @@ jobs: - name: Configure CMake working-directory: ./build - run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON .. + run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF .. - name: Build working-directory: ./build @@ -112,7 +112,7 @@ jobs: - name: Configure CMake working-directory: ./build - run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON .. + run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF .. - name: Build working-directory: ./build diff --git a/.gitignore b/.gitignore index d588aa99..dd2ca4b9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ build/ +build-blas/ build-debug/ build-release/ build-sanitize-addr/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 2af2122d..f8f418bf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,16 @@ endif() # options +if (APPLE) + set(GGML_METAL_DEFAULT ON) + set(GGML_BLAS_DEFAULT ON) + set(GGML_BLAS_VENDOR_DEFAULT "Apple") +else() + set(GGML_METAL_DEFAULT OFF) + set(GGML_BLAS_DEFAULT OFF) + set(GGML_BLAS_VENDOR_DEFAULT "Generic") +endif() + option(BUILD_SHARED_LIBS "ggml: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT}) option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON) @@ -41,11 +51,13 @@ option(GGML_TEST_COVERAGE "ggml: enable test coverage" OFF) option(GGML_PERF "ggml: enable perf timings" OFF) option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF) -option(GGML_OPENBLAS "ggml: use OpenBLAS" OFF) +option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT}) +set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING + "ggml: BLAS library vendor") option(GGML_HIPBLAS "ggml: use hipBLAS" OFF) option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_CUBLAS "ggml: use CUDA (deprecated)" OFF) -option(GGML_METAL "ggml: use Metal" OFF) +option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" OFF) diff --git a/examples/common.h b/examples/common.h index 2ed91ca9..79b98309 100644 --- a/examples/common.h +++ b/examples/common.h @@ -21,7 +21,7 @@ struct gpt_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 200; // new tokens to predict int32_t n_parallel = 1; // number of parallel streams - int32_t n_batch = 8; // batch size for prompt processing + int32_t n_batch = 32; // batch size for prompt processing int32_t n_ctx = 2048; // context size (this is the KV cache max size) int32_t n_gpu_layers = 0; // number of layers to offlload to the GPU diff --git a/examples/gpt-2/main-sched.cpp b/examples/gpt-2/main-sched.cpp index bdf3bff8..11c72973 100644 --- a/examples/gpt-2/main-sched.cpp +++ b/examples/gpt-2/main-sched.cpp @@ -10,6 +10,10 @@ #include "ggml-metal.h" #endif +#ifdef GGML_USE_BLAS +#include "ggml-blas.h" +#endif + #include "common.h" #include "common-ggml.h" @@ -131,6 +135,16 @@ void init_backends(gpt2_model & model, const gpt_params & params) { model.backends.push_back(gpu_backend); } +#ifdef GGML_USE_BLAS + ggml_backend_t blas_backend = ggml_backend_blas_init(); + if (!blas_backend) { + fprintf(stderr, "%s: failed to initialize BLAS backend\n", __func__); + } else { + ggml_backend_blas_set_n_threads(blas_backend, params.n_threads); + model.backends.push_back(blas_backend); + } +#endif + // always add the CPU backend as a fallback ggml_backend_t cpu_backend = ggml_backend_cpu_init(); ggml_backend_cpu_set_n_threads(cpu_backend, params.n_threads); diff --git a/ggml-blas.cpp b/ggml-blas.cpp deleted file mode 100644 index d709a357..00000000 --- a/ggml-blas.cpp +++ /dev/null @@ -1,363 +0,0 @@ -#include "ggml-blas.h" -#include "ggml-backend-impl.h" - -#include -#include - -#if defined(GGML_USE_ACCELERATE) -# include -#elif defined(GGML_BLAS_USE_MKL) -# include -#else -# include -# ifdef BLIS_ENABLE_CBLAS -# include -# endif -#endif - -struct ggml_backend_blas_context { - int n_threads = GGML_DEFAULT_N_THREADS; - std::unique_ptr work_data; - size_t work_size = 0; -#ifndef GGML_USE_OPENMP - std::vector> tasks; -#endif -}; - -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src1->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} - -static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const enum ggml_type type = src0->type; - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - - const int64_t ne_plane = ne01*ne00; - const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float); - - if (ctx->work_size < desired_wsize) { - ctx->work_data.reset(new char[desired_wsize]); - ctx->work_size = desired_wsize; - } - void * wdata = ctx->work_data.get(); - - // convert src0 to float - if (type != GGML_TYPE_F32) { - ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type); - ggml_to_float_t const to_float = type_traits.to_float; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; - - const int min_cols_per_thread = 4096; - const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1); - const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1); - -#ifdef GGML_USE_OPENMP - #pragma omp parallel for num_threads(n_threads) - for (int64_t i01 = 0; i01 < ne01; i01++) { - to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); - } -#else - for (int i = 1; i < n_threads; i++) { - const int64_t start = i*ne01/n_threads; - const int64_t end = (i + 1)*ne01/n_threads; - if (start < end) { - ctx->tasks.push_back(std::async(std::launch::async, [=]() { - for (int64_t i01 = start; i01 < end; i01++) { - to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); - } - })); - } - } - { - // reuse the current thread for the first task - const int64_t start = 0; - const int64_t end = ne01/n_threads; - for (int64_t i01 = start; i01 < end; i01++) { - to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); - } - } -#endif - } - } - -#ifndef GGML_USE_OPENMP - // wait for all tasks to finish - for (auto & task : ctx->tasks) { - task.get(); - } - ctx->tasks.clear(); -#endif - } - -#if defined(OPENBLAS_VERSION) - openblas_set_num_threads(ctx->n_threads); -#endif - -#if defined(BLIS_ENABLE_CBLAS) - bli_thread_set_num_threads(ctx->n_threads); -#endif - - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); - const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - - if (type != GGML_TYPE_F32) { - x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; - } - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne1, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); - } - } -} - -static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne3 == ne13); - GGML_ASSERT(ne03 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - // GGML_ASSERT(nb0 <= nb1); - // GGML_ASSERT(nb1 <= nb2); - // GGML_ASSERT(nb2 <= nb3); - - // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) - // src0: (k,n) - // src1: (k,m) - // dst: (m,n) - // - // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) - // Also expressed as (major,minor) - // a: (m,k): so src1 transposed - // b: (k,n): so src0 - // c: (m,n) - // - // However, if ggml_is_transposed(src1) is true, then - // src1->data already contains a transposed version, so sgemm mustn't - // transpose it further. - - int n = src0->ne[0]; - int k = src0->ne[1]; - int m = src1->ne[0]; - - CBLAS_TRANSPOSE transposeA; - int lda; - - if (!ggml_is_transposed(src1)) { - transposeA = CblasTrans; - lda = m; - } else { - transposeA = CblasNoTrans; - lda = k; - } - - float * a = (float *) ((char *) src1->data); - float * b = (float *) ((char *) src0->data); - float * c = (float *) ((char *) dst->data); - - cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); - - GGML_UNUSED(ctx); -} - -// backend interface - -GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) { - return "BLAS"; - - GGML_UNUSED(backend); -} - -GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) { - ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; - delete ctx; - delete backend; -} - -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) { - return ggml_backend_cpu_buffer_type(); - - GGML_UNUSED(backend); -} - -GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; - - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - switch (node->op) { - case GGML_OP_MUL_MAT: - ggml_backend_blas_mul_mat(ctx, node); - break; - - case GGML_OP_OUT_PROD: - ggml_backend_blas_out_prod(ctx, node); - break; - - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - break; - - default: - fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node)); - GGML_ASSERT(false); - } - } - - return GGML_STATUS_SUCCESS; - - GGML_UNUSED(backend); -} - -GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - - return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) || - (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 && - op->src[1]->type == GGML_TYPE_F32 && - ggml_is_matrix(src0) && - ggml_is_matrix(src1) && - ggml_is_contiguous(src0) && - (ggml_is_contiguous(src1) || ggml_is_transposed(src1))); - - GGML_UNUSED(backend); -} - -GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return ggml_backend_buft_is_host(buft); - - GGML_UNUSED(backend); -} - -static struct ggml_backend_i blas_backend_i = { - /* .get_name = */ ggml_backend_blas_name, - /* .free = */ ggml_backend_blas_free, - /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_blas_graph_compute, - /* .supports_op = */ ggml_backend_blas_supports_op, - /* .supports_buft = */ ggml_backend_blas_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, -}; - -static ggml_guid_t ggml_backend_blas_guid(void) { - static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d }; - return &guid; -} - -ggml_backend_t ggml_backend_blas_init(void) { - ggml_backend_blas_context * ctx = new ggml_backend_blas_context; - - ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_blas_guid(), - /* .interface = */ blas_backend_i, - /* .context = */ ctx, - }; - -#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) - if (openblas_get_parallel() != OPENBLAS_OPENMP) { - fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); - } -#endif - -#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) - fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); -#endif - - return backend; -} - -GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid()); -} - -void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) { - GGML_ASSERT(ggml_backend_is_blas(backend_blas)); - - ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context; - ctx->n_threads = n_threads; -} diff --git a/ggml-blas.h b/ggml-blas.h deleted file mode 100644 index f2e37de0..00000000 --- a/ggml-blas.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "ggml.h" -#include "ggml-backend.h" - - -#ifdef __cplusplus -extern "C" { -#endif - -// backend API -GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void); - -GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend); - -// number of threads used for conversion to float -// for openblas and blis, this will also set the number of threads used for blas operations -GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); - - -#ifdef __cplusplus -} -#endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 25fb1ad1..0330c9b3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -152,28 +152,89 @@ if (APPLE AND NOT GGML_NO_ACCELERATE) endif() endif() -if (GGML_OPENBLAS) - set(OPENBLAS_INCLUDE_SEARCH_PATHS - /usr/include - /usr/include/openblas - /usr/include/openblas-base - /usr/local/include - /usr/local/include/openblas - /usr/local/include/openblas-base - /opt/OpenBLAS/include - $ENV{OpenBLAS_HOME} - $ENV{OpenBLAS_HOME}/include - ) - find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) - find_library(OPENBLAS_LIB NAMES openblas libopenblas) - if (OPENBLAS_LIB) - message(STATUS "OpenBLAS found") - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${OPENBLAS_LIB}) - set(GGML_EXTRA_INCS ${GGML_EXTRA_INCS} ${OPENBLAS_INC}) - set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) +if (GGML_BLAS) + if (GGML_STATIC) + set(BLA_STATIC ON) + endif() + #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22) + # set(BLA_SIZEOF_INTEGER 8) + #endif() + + set(BLA_VENDOR ${GGML_BLAS_VENDOR}) + find_package(BLAS) + + if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + if (("${BLAS_INCLUDE_DIRS}" STREQUAL "") AND NOT (${GGML_BLAS_VENDOR} MATCHES "Apple")) + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${GGML_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS REQUIRED blas) + elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS") + # As of openblas v0.3.22, the 64-bit is named openblas64.pc + pkg_check_modules(DepBLAS openblas64) + if (NOT DepBLAS_FOUND) + pkg_check_modules(DepBLAS REQUIRED openblas) + endif() + elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME") + pkg_check_modules(DepBLAS REQUIRED blis) + elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS REQUIRED blas-atlas) + elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS REQUIRED flexiblas_api) + elseif (${GGML_BLAS_VENDOR} MATCHES "Intel") + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) + elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + + add_compile_options(${BLAS_LINKER_FLAGS}) + + add_compile_definitions(GGML_USE_BLAS) + + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + + set(GGML_HEADERS_BLAS ggml-blas.h) + set(GGML_SOURCES_BLAS ggml-blas.cpp) + + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_BLAS) else() - message(WARNING "OpenBLAS not found") + message(WARNING "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct GGML_BLAS_VENDOR") endif() endif() @@ -513,9 +574,10 @@ add_library(${TARGET} ../include/ggml/ggml.h ../include/ggml/ggml-alloc.h ../include/ggml/ggml-backend.h - ${GGML_SOURCES_CUDA} - ${GGML_SOURCES_METAL} - ${GGML_SOURCES_RPC} + ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} + ${GGML_SOURCES_RPC} ${GGML_HEADERS_RPC} + ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} ) target_include_directories(${TARGET} PUBLIC diff --git a/src/ggml-blas.cpp b/src/ggml-blas.cpp new file mode 100644 index 00000000..d709a357 --- /dev/null +++ b/src/ggml-blas.cpp @@ -0,0 +1,363 @@ +#include "ggml-blas.h" +#include "ggml-backend-impl.h" + +#include +#include + +#if defined(GGML_USE_ACCELERATE) +# include +#elif defined(GGML_BLAS_USE_MKL) +# include +#else +# include +# ifdef BLIS_ENABLE_CBLAS +# include +# endif +#endif + +struct ggml_backend_blas_context { + int n_threads = GGML_DEFAULT_N_THREADS; + std::unique_ptr work_data; + size_t work_size = 0; +#ifndef GGML_USE_OPENMP + std::vector> tasks; +#endif +}; + +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { + + /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ + return true; + } + + return false; +} + +static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const int64_t ne_plane = ne01*ne00; + const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float); + + if (ctx->work_size < desired_wsize) { + ctx->work_data.reset(new char[desired_wsize]); + ctx->work_size = desired_wsize; + } + void * wdata = ctx->work_data.get(); + + // convert src0 to float + if (type != GGML_TYPE_F32) { + ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type); + ggml_to_float_t const to_float = type_traits.to_float; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + + const int min_cols_per_thread = 4096; + const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1); + const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1); + +#ifdef GGML_USE_OPENMP + #pragma omp parallel for num_threads(n_threads) + for (int64_t i01 = 0; i01 < ne01; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } +#else + for (int i = 1; i < n_threads; i++) { + const int64_t start = i*ne01/n_threads; + const int64_t end = (i + 1)*ne01/n_threads; + if (start < end) { + ctx->tasks.push_back(std::async(std::launch::async, [=]() { + for (int64_t i01 = start; i01 < end; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } + })); + } + } + { + // reuse the current thread for the first task + const int64_t start = 0; + const int64_t end = ne01/n_threads; + for (int64_t i01 = start; i01 < end; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } + } +#endif + } + } + +#ifndef GGML_USE_OPENMP + // wait for all tasks to finish + for (auto & task : ctx->tasks) { + task.get(); + } + ctx->tasks.clear(); +#endif + } + +#if defined(OPENBLAS_VERSION) + openblas_set_num_threads(ctx->n_threads); +#endif + +#if defined(BLIS_ENABLE_CBLAS) + bli_thread_set_num_threads(ctx->n_threads); +#endif + + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + + if (type != GGML_TYPE_F32) { + x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + } + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne1, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } +} + +static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne3 == ne13); + GGML_ASSERT(ne03 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) + // src0: (k,n) + // src1: (k,m) + // dst: (m,n) + // + // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) + // Also expressed as (major,minor) + // a: (m,k): so src1 transposed + // b: (k,n): so src0 + // c: (m,n) + // + // However, if ggml_is_transposed(src1) is true, then + // src1->data already contains a transposed version, so sgemm mustn't + // transpose it further. + + int n = src0->ne[0]; + int k = src0->ne[1]; + int m = src1->ne[0]; + + CBLAS_TRANSPOSE transposeA; + int lda; + + if (!ggml_is_transposed(src1)) { + transposeA = CblasTrans; + lda = m; + } else { + transposeA = CblasNoTrans; + lda = k; + } + + float * a = (float *) ((char *) src1->data); + float * b = (float *) ((char *) src0->data); + float * c = (float *) ((char *) dst->data); + + cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); + + GGML_UNUSED(ctx); +} + +// backend interface + +GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) { + return "BLAS"; + + GGML_UNUSED(backend); +} + +GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) { + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; + delete ctx; + delete backend; +} + +GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(backend); +} + +GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_MUL_MAT: + ggml_backend_blas_mul_mat(ctx, node); + break; + + case GGML_OP_OUT_PROD: + ggml_backend_blas_out_prod(ctx, node); + break; + + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + break; + + default: + fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node)); + GGML_ASSERT(false); + } + } + + return GGML_STATUS_SUCCESS; + + GGML_UNUSED(backend); +} + +GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) || + (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + ggml_is_matrix(src0) && + ggml_is_matrix(src1) && + ggml_is_contiguous(src0) && + (ggml_is_contiguous(src1) || ggml_is_transposed(src1))); + + GGML_UNUSED(backend); +} + +GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + + GGML_UNUSED(backend); +} + +static struct ggml_backend_i blas_backend_i = { + /* .get_name = */ ggml_backend_blas_name, + /* .free = */ ggml_backend_blas_free, + /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_blas_graph_compute, + /* .supports_op = */ ggml_backend_blas_supports_op, + /* .supports_buft = */ ggml_backend_blas_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static ggml_guid_t ggml_backend_blas_guid(void) { + static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d }; + return &guid; +} + +ggml_backend_t ggml_backend_blas_init(void) { + ggml_backend_blas_context * ctx = new ggml_backend_blas_context; + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_blas_guid(), + /* .interface = */ blas_backend_i, + /* .context = */ ctx, + }; + +#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) + if (openblas_get_parallel() != OPENBLAS_OPENMP) { + fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); + } +#endif + +#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) + fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); +#endif + + return backend; +} + +GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid()); +} + +void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) { + GGML_ASSERT(ggml_backend_is_blas(backend_blas)); + + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context; + ctx->n_threads = n_threads; +} diff --git a/src/ggml-blas.h b/src/ggml-blas.h new file mode 100644 index 00000000..f2e37de0 --- /dev/null +++ b/src/ggml-blas.h @@ -0,0 +1,23 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void); + +GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend); + +// number of threads used for conversion to float +// for openblas and blis, this will also set the number of threads used for blas operations +GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); + + +#ifdef __cplusplus +} +#endif