]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add CLBlast support (#1164)
author0cc4m <redacted>
Fri, 28 Apr 2023 14:57:16 +0000 (16:57 +0200)
committerGitHub <redacted>
Fri, 28 Apr 2023 14:57:16 +0000 (17:57 +0300)
* Allow use of OpenCL GPU-based BLAS using ClBlast instead of OpenBLAS for context processing

* Improve ClBlast implementation, avoid recreating buffers, remove redundant transfers

* Finish merge of ClBlast support

* Move CLBlast implementation to separate file

Add buffer reuse code (adapted from slaren's cuda implementation)

* Add q4_2 and q4_3 CLBlast support, improve code

* Double CLBlast speed by disabling OpenBLAS thread workaround

Co-authored-by: Concedo <redacted>
Co-authored-by: slaren <redacted>
* Fix device selection env variable names

* Fix cast in opencl kernels

* Add CLBlast to CMakeLists.txt

* Replace buffer pool with static buffers a, b, qb, c

Fix compile warnings

* Fix typos, use GGML_TYPE defines, improve code

* Improve btype dequant kernel selection code, add error if type is unsupported

* Improve code quality

* Move internal stuff out of header
* Use internal enums instead of CLBlast enums
* Remove leftover C++ includes and defines
* Make event use easier to read

Co-authored-by: Henri Vasserman <redacted>
* Use c compiler for opencl files

* Simplify code, fix include

* First check error, then release event

* Make globals static, fix indentation

* Rename dequant kernels file to conform with other file names

* Fix import cl file name

---------

Co-authored-by: Concedo <redacted>
Co-authored-by: slaren <redacted>
Co-authored-by: Henri Vasserman <redacted>
Co-authored-by: Georgi Gerganov <redacted>
CMakeLists.txt
Makefile
ggml-opencl-dequant.cl [new file with mode: 0644]
ggml-opencl.c [new file with mode: 0644]
ggml-opencl.h [new file with mode: 0644]
ggml.c
ggml.h
llama.cpp

index 11ebe9eb66fae1f21e99f33616aca11206f96c66..5fdbeddfca443cf465d0f2f76c6b6c6d56575037 100644 (file)
@@ -67,6 +67,7 @@ endif()
 option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON)
 option(LLAMA_OPENBLAS               "llama: use OpenBLAS"                                   OFF)
 option(LLAMA_CUBLAS                 "llama: use cuBLAS"                                     OFF)
+option(LLAMA_CLBLAST                "llama: use CLBlast"                                    OFF)
 
 option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE})
@@ -168,6 +169,21 @@ if (LLAMA_CUBLAS)
     endif()
 endif()
 
+if (LLAMA_CLBLAST)
+    find_package(CLBlast)
+    if (CLBlast_FOUND)
+        message(STATUS "CLBlast found")
+
+        set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)
+
+        add_compile_definitions(GGML_USE_CLBLAST)
+
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
+    else()
+        message(WARNING "CLBlast not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
@@ -307,7 +323,8 @@ endif()
 add_library(ggml OBJECT
             ggml.c
             ggml.h
-            ${GGML_CUDA_SOURCES})
+            ${GGML_CUDA_SOURCES}
+            ${GGML_OPENCL_SOURCES})
 
 target_include_directories(ggml PUBLIC .)
 target_compile_features(ggml PUBLIC c_std_11) # don't bump
index f7c8dbfdc64acd987297e5ee8b55114648f9fcce..0715e857bc34663da2bf0e7d40a71fb8041f9ac0 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -105,14 +105,21 @@ ifdef LLAMA_OPENBLAS
        LDFLAGS += -lopenblas
 endif
 ifdef LLAMA_CUBLAS
-       CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-       LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
+       CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
+       LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
        OBJS      += ggml-cuda.o
        NVCC      = nvcc
        NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif
+ifdef LLAMA_CLBLAST
+       CFLAGS  += -DGGML_USE_CLBLAST
+       LDFLAGS += -lclblast -lOpenCL
+       OBJS    += ggml-opencl.o
+ggml-opencl.o: ggml-opencl.c ggml-opencl.h
+       $(CC) $(CFLAGS) -c $< -o $@
+endif
 ifdef LLAMA_GPROF
        CFLAGS   += -pg
        CXXFLAGS += -pg
diff --git a/ggml-opencl-dequant.cl b/ggml-opencl-dequant.cl
new file mode 100644 (file)
index 0000000..191b2e5
--- /dev/null
@@ -0,0 +1,84 @@
+#define MULTILINE_QUOTE(...) #__VA_ARGS__
+const char * clblast_dequant = MULTILINE_QUOTE(
+
+struct block_q4_0
+{
+    float d;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*32 + l*2;
+    result[index + 0] = ((vi & 0xf) - 8)*d;
+    result[index + 1] = ((vi >> 4) - 8)*d;
+}
+
+struct block_q4_1
+{
+    float d;
+    float m;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+    const float m = blocks[i].m;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*32 + l*2;
+    result[index + 0] = (vi & 0xf) * d + m;
+    result[index + 1] = (vi >> 4) * d + m;
+}
+
+struct block_q4_2
+{
+    ushort d;
+    uchar qs[8];
+};
+
+__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 16;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &blocks[i].d);;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*16 + l*2;
+    result[index + 0] = ((vi & 0xf) - 8)*d;
+    result[index + 1] = ((vi >> 4) - 8)*d;
+}
+
+struct block_q4_3
+{
+    ushort d;
+    ushort m;
+    uchar qs[8];
+};
+
+__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 16;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &(blocks[i].d));
+    const float m = vload_half(0, (__global half*) &(blocks[i].m));
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*16 + l*2;
+    result[index + 0] = (vi & 0xf) * d + m;
+    result[index + 1] = (vi >> 4) * d + m;
+}
+
+);
diff --git a/ggml-opencl.c b/ggml-opencl.c
new file mode 100644 (file)
index 0000000..1d68f19
--- /dev/null
@@ -0,0 +1,216 @@
+#include "ggml-opencl.h"
+
+#define CL_TARGET_OPENCL_VERSION 110
+#include <clblast_c.h>
+
+#include <stdio.h>
+#include <string.h>
+
+#include "ggml.h"
+
+#include "ggml-opencl-dequant.cl"
+
+#define CL_CHECK(err, name)                                                                     \
+    do {                                                                                        \
+        cl_int err_ = (err);                                                                    \
+        if (err_ != CL_SUCCESS) {                                                               \
+            fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__);   \
+            exit(1);                                                                            \
+        }                                                                                       \
+    } while (0)
+
+static cl_platform_id platform;
+static cl_device_id device;
+static cl_context context;
+static cl_command_queue queue;
+static cl_program program;
+static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3;
+static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
+static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
+
+static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
+    cl_program p;
+    char *program_log;
+    size_t program_size, log_size;
+    int err;
+
+    program_size = strlen(program_buffer);
+
+    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
+    if(err < 0) {
+        fprintf(stderr, "OpenCL error creating program");
+        exit(1);
+    }
+
+    err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL);
+    if(err < 0) {
+
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
+        program_log = (char*) malloc(log_size + 1);
+        program_log[log_size] = '\0';
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
+        printf("%s\n", program_log);
+        free(program_log);
+        exit(1);
+    }
+
+    return p;
+}
+
+void ggml_cl_init(void) {
+    cl_int err = 0;
+    char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
+    char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
+    int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
+    int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
+    printf("\nInitializing CLBlast (First Run)...");
+    printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
+    cl_uint num_platforms;
+    clGetPlatformIDs(0, NULL, &num_platforms);
+    cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
+    clGetPlatformIDs(num_platforms, platforms, NULL);
+    platform = platforms[plat_num];
+    char platform_buffer[1024];
+    clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
+    cl_uint num_devices;
+    clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
+    cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
+    clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
+    device = devices[dev_num];
+    char device_buffer[1024];
+    clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
+    printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
+    context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
+    CL_CHECK(err, "clCreateContext");
+    queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
+    CL_CHECK(err, "clCreateCommandQueue");
+
+    free(platforms);
+    free(devices);
+
+    program = build_program_from_source(context, device, clblast_dequant);
+
+    // Prepare dequantize kernels
+    kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err);
+    CL_CHECK(err, "clCreateKernel");
+}
+
+static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
+    if (req_size <= *cur_size) {
+        return;
+    }
+
+    // Reallocate buffer with enough space
+    if (*cur_size > 0) {
+        clReleaseMemObject(*buf);
+    }
+    cl_int err;
+    *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
+    *cur_size = req_size;
+    CL_CHECK(err, "clCreateBuffer");
+}
+
+void ggml_cl_sgemm_wrapper(
+        const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
+        const int m, const int n, const int k,
+        const float alpha, const void *host_a, const int lda,
+        const float *host_b, const int ldb, const float beta,
+        float *host_c, const int ldc, const int btype) {
+    cl_int err = 0;
+
+    cl_kernel kernel;
+    size_t global = n * k, local, size_qb;
+    bool dequant;
+
+    switch (btype) {
+    case GGML_TYPE_F32:
+        dequant = false;
+        break;
+    case GGML_TYPE_Q4_0:
+        dequant = true;
+        kernel = kernel_q4_0;
+        local = 16;
+        size_qb = global * (sizeof(float) + local) / 32;
+        break;
+    case GGML_TYPE_Q4_1:
+        dequant = true;
+        kernel = kernel_q4_1;
+        local = 16;
+        size_qb = global * (sizeof(float) * 2 + local) / 32;
+        break;
+    case GGML_TYPE_Q4_2:
+        dequant = true;
+        kernel = kernel_q4_2;
+        local = 8;
+        size_qb = global * (sizeof(short) + local) / 16;
+        break;
+    case GGML_TYPE_Q4_3:
+        dequant = true;
+        kernel = kernel_q4_3;
+        local = 8;
+        size_qb = global * (sizeof(short) * 2 + local) / 16;
+        break;
+    default:
+        fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
+        abort();
+    }
+
+    const size_t size_a =  m * k * sizeof(float);
+    const size_t size_b =  n * k * sizeof(float);
+    const size_t size_c =  m * n * sizeof(float);
+
+    // Prepare buffers
+    ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
+    if (dequant) {
+        ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
+    }
+    ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
+    ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
+
+    cl_event ev_a, ev_qb, ev_b;
+
+    if (dequant) {
+        err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
+        err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
+        CL_CHECK(err, "clSetKernelArg");
+        clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
+    } else {
+        clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
+    }
+
+    clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
+    if (dequant) {
+        err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
+        CL_CHECK(err, "clEnqueueNDRangeKernel");
+        clReleaseEvent(ev_qb);
+    }
+    clWaitForEvents(1, &ev_a);
+    clWaitForEvents(1, &ev_b);
+    clReleaseEvent(ev_a);
+    clReleaseEvent(ev_b);
+
+    cl_event ev_sgemm;
+    CLBlastSgemm((CLBlastLayout)order,
+                 (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
+                 m, n, k,
+                 alpha,
+                 cl_buffer_a, 0, lda,
+                 cl_buffer_b, 0, ldb,
+                 beta,
+                 cl_buffer_c, 0, ldc,
+                 &queue, &ev_sgemm);
+
+    cl_event ev_c;
+    clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
+
+    // Wait for completion
+    clWaitForEvents(1, &ev_c);
+    clReleaseEvent(ev_sgemm);
+    clReleaseEvent(ev_c);
+}
diff --git a/ggml-opencl.h b/ggml-opencl.h
new file mode 100644 (file)
index 0000000..7bcc603
--- /dev/null
@@ -0,0 +1,24 @@
+#pragma once
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+void ggml_cl_init(void);
+
+enum ggml_blas_order {
+    GGML_BLAS_ORDER_ROW_MAJOR = 101,
+    GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
+};
+
+enum ggml_blas_op {
+    GGML_BLAS_OP_N = 111,
+    GGML_BLAS_OP_T = 112,
+    GGML_BLAS_OP_C = 113,
+};
+
+void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml.c b/ggml.c
index 1fbf2955d67308b8c2f25ab72aef6cf57b8c637a..33fb1681eaec41f42cb073aa3318d896cc81ce8b 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -149,6 +149,8 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
 #include "ggml-cuda.h"
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif
 
 #undef MIN
@@ -4363,6 +4365,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
         ggml_init_cublas();
+        #elif defined(GGML_USE_CLBLAST)
+        ggml_cl_init();
         #endif
 
         is_first_call = false;
@@ -8104,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -8129,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 
     return false;
 }
+
 #endif
 
 static void ggml_compute_forward_mul_mat_f32(
@@ -8144,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -8201,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8250,8 +8255,15 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#else
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -8395,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -8472,6 +8484,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -8646,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8698,7 +8723,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else {
             GGML_ASSERT(false);
         }
-#else
+#elif !defined(GGML_USE_CLBLAST)
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
@@ -8717,6 +8742,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
+#elif defined(GGML_USE_CLBLAST)
+                const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
                 {
                     size_t id = 0;
@@ -8743,8 +8770,15 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#else
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        type);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -11583,7 +11617,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -11600,7 +11634,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -13100,7 +13134,7 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     return 1;
 #else
     return 0;
@@ -13115,6 +13149,18 @@ int ggml_cpu_has_cublas(void) {
 #endif
 }
 
+int ggml_cpu_has_clblast(void) {
+#if defined(GGML_USE_CLBLAST)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_gpublas(void) {
+    return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
+}
+
 int ggml_cpu_has_sse3(void) {
 #if defined(__SSE3__)
     return 1;
diff --git a/ggml.h b/ggml.h
index d9d3d214e84e70f827d59b11ec2a35d47fb9b26f..1bbe2db93f5d1e6f79e29e03d8d293a1bdada76f 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -858,10 +858,11 @@ extern "C" {
     GGML_API int ggml_cpu_has_wasm_simd  (void);
     GGML_API int ggml_cpu_has_blas       (void);
     GGML_API int ggml_cpu_has_cublas     (void);
+    GGML_API int ggml_cpu_has_clblast    (void);
+    GGML_API int ggml_cpu_has_gpublas    (void);
     GGML_API int ggml_cpu_has_sse3       (void);
     GGML_API int ggml_cpu_has_vsx        (void);
 
-
     //
     // Internal types and functions exposed for tests and benchmarks
     //
index 28a74b514b852488a996e3aec0d4142da027879a..bfebf14bfde3ffb57045952c3788cf3d883618d1 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1085,7 +1085,7 @@ static bool llama_eval_internal(
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
+    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));