]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync ggml (clBLAST + tensor names)
authorGeorgi Gerganov <redacted>
Tue, 2 May 2023 18:23:54 +0000 (21:23 +0300)
committerGeorgi Gerganov <redacted>
Tue, 2 May 2023 18:24:18 +0000 (21:24 +0300)
extra/sync-ggml.sh
ggml-cuda.cu
ggml-cuda.h
ggml-opencl.c [new file with mode: 0644]
ggml-opencl.h [new file with mode: 0644]
ggml.c
ggml.h

index 2fa392ffed373e81d7a5b1bbf05b3229718c7b52..8629860d44af9971d3e96835ca296173e076faaf 100755 (executable)
@@ -1,8 +1,10 @@
 #!/bin/bash
 
 cp -rpv ../ggml/src/ggml.c               ./ggml.c
-cp -rpv ../ggml/src/ggml-cuda.cu         ./ggml-cuda.cu
 cp -rpv ../ggml/src/ggml-cuda.h          ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu         ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h        ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.c        ./ggml-opencl.c
 cp -rpv ../ggml/include/ggml/ggml.h      ./ggml.h
 cp -rpv ../ggml/examples/common.h        ./examples/common.h
 cp -rpv ../ggml/examples/common.cpp      ./examples/common.cpp
index 5a2701cfeef68696730b7c3c067fb41a60ad221d..e8a1e77cb06fc0b947aec8de0de21c0988666f6f 100644 (file)
@@ -1,11 +1,38 @@
+#include <cstddef>
+#include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
-#include <cuda_fp16.h>
 #include <atomic>
-#include "ggml-cuda.h"
 
-typedef uint16_t ggml_fp16_t;
-static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+#include <cuda_fp16.h>
+
+#include "ggml-cuda.h"
+#include "ggml.h"
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 
 #define QK4_0 32
 typedef struct {
@@ -24,14 +51,14 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b
 
 #define QK4_2 16
 typedef struct {
-    __half  d;              // delta
+    half  d;                // delta
     uint8_t qs[QK4_2 / 2];  // nibbles / quants
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
 #define QK5_0 32
 typedef struct {
-    __half d;               // delta
+    half d;                 // delta
     uint8_t qh[4];          // 5-th bit of quants
     uint8_t qs[QK5_0 / 2];  // nibbles / quants
 } block_q5_0;
@@ -39,9 +66,9 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 
 #define QK5_1 32
 typedef struct {
-    __half d;               // delta
-    __half m;               // min
-    uint32_t qh;            // 5-th bit of quants
+    half d;                 // delta
+    half m;                 // min
+    uint8_t qh[4];          // 5-th bit of quants
     uint8_t qs[QK5_1 / 2];  // nibbles / quants
 } block_q5_1;
 static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
@@ -162,7 +189,8 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
 
     const uint8_t * pp = x[i].qs;
 
-    const uint32_t qh = x[i].qh;
+    uint32_t qh;
+    memcpy(&qh, x[i].qh, sizeof(qh));
 
     for (int l = 0; l < QK5_1; l += 2) {
         const uint8_t vi = pp[l/2];
@@ -197,37 +225,50 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_1;
     dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_2;
     dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK5_0;
     dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK5_1;
     dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK8_0;
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+// TODO: optimize
+static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
+    const half * x = (const half *) vx;
+
+    const int i = blockIdx.x;
+
+    y[i] = __half2float(x[i]);
+}
+
+static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
+    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
@@ -241,6 +282,8 @@ dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
             return dequantize_row_q5_1_cuda;
         case GGML_TYPE_Q8_0:
             return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_F16:
+            return convert_fp16_to_fp32_cuda;
         default:
             return nullptr;
     }
@@ -271,7 +314,7 @@ struct cuda_buffer {
 static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
-void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -290,7 +333,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-void ggml_cuda_pool_free(void * ptr, size_t size) {
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -305,28 +348,55 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t g_cublasH = nullptr;
-cudaStream_t g_cudaStream = nullptr;
-cudaStream_t g_cudaStream2 = nullptr;
-cudaEvent_t g_cudaEvent = nullptr;
+#define GGML_CUDA_MAX_STREAMS 8
+#define GGML_CUDA_MAX_EVENTS 64
+static cublasHandle_t g_cublasH = nullptr;
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
 
 void ggml_init_cublas() {
     if (g_cublasH == nullptr) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&g_cublasH));
-        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
-        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+        // create streams
+        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
+        }
+        // create events
+        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
+        }
 
-        // create additional stream and event for synchronization
-        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
-        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+        // create cublas handle
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
 
         // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
     }
 }
 
-cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+void * ggml_cuda_host_malloc(size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return nullptr;
+    }
+
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+            size/1024.0/1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
+
+static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
     const uint64_t ne0 = src->ne[0];
     const uint64_t ne1 = src->ne[1];
     const uint64_t nb0 = src->nb[0];
@@ -354,12 +424,293 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
     }
 }
 
-void * ggml_cuda_host_malloc(size_t size) {
-    void * ptr;
-    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
-    return ptr;
+static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const int n_mm = ne03 * ne02;
+
+    size_t x_size, y_size, d_size;
+    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            int i = i03*ne02 + i02;
+            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+
+            float * c_X = d_X + i * x_ne;
+            float * c_Y = d_Y + i * y_ne;
+            float * c_D = d_D + i * d_ne;
+
+            // copy data to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+            CUBLAS_CHECK(
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, ne00,
+                                c_Y, ne10,
+                        &beta,  c_D, ne01));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
 }
 
-void ggml_cuda_host_free(void * ptr) {
-    CUDA_CHECK(cudaFreeHost(ptr));
+static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const int n_mm = ne03 * ne02;
+
+    size_t x_size, y_size, d_size;
+    half  * d_X =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
+    half  * d_Y =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+
+    bool src1_cont_rows = nb10 == sizeof(float);
+    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            int i = i03*ne02 + i02;
+            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+
+            half  * c_X = d_X + i * x_ne;
+            half  * c_Y = d_Y + i * y_ne;
+            float * c_D = d_D + i * d_ne;
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
+
+            // convert src1 to fp16
+            // TODO: use multiple threads
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
+            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            if (src1_cont_rows) {
+                if (src1_cont_cols) {
+                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+                }
+                else {
+                    for (int64_t i01 = 0; i01 < ne11; i01++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    }
+                }
+            }
+            else {
+                for (int64_t i01 = 0; i01 < ne11; i01++) {
+                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                        // very slow due to no inlining
+                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                    }
+                }
+            }
+
+            // copy src1 to device
+            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+            CUBLAS_CHECK(
+                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, CUDA_R_16F, ne00,
+                                c_Y, CUDA_R_16F, ne10,
+                        &beta,  c_D, CUDA_R_32F, ne01,
+                        CUBLAS_COMPUTE_32F_FAST_16F,
+                        CUBLAS_GEMM_DEFAULT));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
+static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    const ggml_type type = src0->type;
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const int n_mm = ne03 * ne02;
+    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
+
+    size_t x_size, y_size, d_size, q_size;
+    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+    char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+    GGML_ASSERT(to_fp32_cuda != nullptr);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            int i = i03*ne02 + i02;
+            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
+
+            float * c_X = d_X + i * x_ne;
+            float * c_Y = d_Y + i * y_ne;
+            float * c_D = d_D + i * d_ne;
+            char  * c_Q = d_Q + i * q_sz;
+
+            // copy src0 and convert to fp32 on device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+            CUDA_CHECK(cudaGetLastError());
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // copy src1 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+            // wait for conversion
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+            CUBLAS_CHECK(
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, ne00,
+                                c_Y, ne10,
+                        &beta,  c_D, ne01));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+    ggml_cuda_pool_free(d_Q, q_size);
+}
+
+bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+        return true;
+    }
+
+    return false;
+}
+
+bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+    size_t src0_sz = ggml_nbytes(src0);
+    size_t src1_sz = ggml_nbytes(src1);
+
+    // mul_mat_q: src0 is converted to fp32 on device
+    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+
+    // mul_mat_f16: src1 is converted to fp16 on cpu
+    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
+
+    // choose the smaller one to transfer to the device
+    // TODO: this is not always the best choice due to the overhead of converting to fp16
+    return mul_mat_f16_transfer < mul_mat_q_transfer;
+}
+
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
+    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
+
+    if (src0->type == GGML_TYPE_F32) {
+        ggml_cuda_mul_mat_f32(src0, src1, dst);
+    }
+    else if (src0->type == GGML_TYPE_F16) {
+        if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
+            ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
+        }
+        else {
+            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+        }
+    }
+    else if (ggml_is_quantized(src0->type)) {
+        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+    }
+    else {
+        GGML_ASSERT(false);
+    }
+}
+
+size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
+        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    }
+    else {
+        return 0;
+    }
 }
index 36782d9e796b7ed873a036cc3cfb542ead3e1116..f7d6a8bc1842ac2ba9fb350215d5e7ee2e8f18ac 100644 (file)
@@ -1,54 +1,19 @@
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
-#define CUDA_CHECK(err)                                                                 \
-    do {                                                                                \
-        cudaError_t err_ = (err);                                                       \
-        if (err_ != cudaSuccess) {                                                      \
-            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                              \
-            exit(1);                                                                    \
-        }                                                                               \
-    } while (0)
-
-#define CUBLAS_CHECK(err)                                                               \
-    do {                                                                                \
-        cublasStatus_t err_ = (err);                                                    \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                                    \
-        }                                                                               \
-    } while (0)
+void   ggml_init_cublas(void);
 
-extern cublasHandle_t g_cublasH;
-extern cudaStream_t g_cudaStream;
-extern cudaStream_t g_cudaStream2;
-extern cudaEvent_t g_cudaEvent;
+bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
 
-void   ggml_init_cublas(void);
+// TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
-void   ggml_cuda_pool_free(void * ptr, size_t size);
-
-void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-
-cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
-
-typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
-
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml-opencl.c b/ggml-opencl.c
new file mode 100644 (file)
index 0000000..4389eca
--- /dev/null
@@ -0,0 +1,398 @@
+#include "ggml-opencl.h"
+
+#define CL_TARGET_OPENCL_VERSION 110
+#include <clblast_c.h>
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "ggml.h"
+
+#define MULTILINE_QUOTE(...) #__VA_ARGS__
+const char * clblast_dequant = MULTILINE_QUOTE(
+
+struct block_q4_0
+{
+    float d;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*32 + l*2;
+    result[index + 0] = ((vi & 0xf) - 8)*d;
+    result[index + 1] = ((vi >> 4) - 8)*d;
+}
+
+struct block_q4_1
+{
+    float d;
+    float m;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+    const float m = blocks[i].m;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*32 + l*2;
+    result[index + 0] = (vi & 0xf) * d + m;
+    result[index + 1] = (vi >> 4) * d + m;
+}
+
+struct block_q4_2
+{
+    ushort d;
+    uchar qs[8];
+};
+
+__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 16;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &blocks[i].d);
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint index = i*16 + l*2;
+    result[index + 0] = ((vi & 0xf) - 8)*d;
+    result[index + 1] = ((vi >> 4) - 8)*d;
+}
+
+
+struct block_q5_0
+{
+    float d;
+    uint qh;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint l2 = l * 2;
+
+    const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
+    const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
+
+    const uint index = i*32 + l2;
+    result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
+    result[index + 1] = (((vi >>  4) | vh1) - 16)*d;
+}
+
+struct block_q5_1
+{
+    ushort d;
+    ushort m;
+    uint qh;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &blocks[i].d);
+    const float m = vload_half(0, (__global half*) &blocks[i].m);
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint l2 = l * 2;
+
+    const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
+    const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
+
+    const uint index = i*32 + l2;
+    result[index + 0] = ((vi & 0xf) | vh0)*d + m;
+    result[index + 1] = ((vi >>  4) | vh1)*d + m;
+}
+
+struct block_q8_0
+{
+    float d;
+    char qs[32];
+};
+
+__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
+}
+
+);
+
+#define CL_CHECK(err, name)                                                                     \
+    do {                                                                                        \
+        cl_int err_ = (err);                                                                    \
+        if (err_ != CL_SUCCESS) {                                                               \
+            fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__);   \
+            exit(1);                                                                            \
+        }                                                                                       \
+    } while (0)
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+
+typedef struct {
+    float d;                // delta
+    uint32_t qh;          // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2];  // nibbles / quants
+} cl_block_q5_0;
+
+static cl_platform_id platform;
+static cl_device_id device;
+static cl_context context;
+static cl_command_queue queue;
+static cl_program program;
+static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
+static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
+static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
+
+static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
+    cl_program p;
+    char *program_log;
+    size_t program_size, log_size;
+    int err;
+
+    program_size = strlen(program_buffer);
+
+    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
+    if(err < 0) {
+        fprintf(stderr, "OpenCL error creating program");
+        exit(1);
+    }
+
+    err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL);
+    if(err < 0) {
+
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
+        program_log = (char*) malloc(log_size + 1);
+        program_log[log_size] = '\0';
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
+        printf("%s\n", program_log);
+        free(program_log);
+        exit(1);
+    }
+
+    return p;
+}
+
+void ggml_cl_init(void) {
+    cl_int err = 0;
+    char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
+    char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
+    int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
+    int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
+    printf("\nInitializing CLBlast (First Run)...");
+    printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
+    cl_uint num_platforms;
+    clGetPlatformIDs(0, NULL, &num_platforms);
+    cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
+    clGetPlatformIDs(num_platforms, platforms, NULL);
+    platform = platforms[plat_num];
+    char platform_buffer[1024];
+    clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
+    cl_uint num_devices;
+    clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
+    cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
+    clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
+    device = devices[dev_num];
+    char device_buffer[1024];
+    clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
+    printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
+    context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
+    CL_CHECK(err, "clCreateContext");
+    queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
+    CL_CHECK(err, "clCreateCommandQueue");
+
+    free(platforms);
+    free(devices);
+
+    program = build_program_from_source(context, device, clblast_dequant);
+
+    // Prepare dequantize kernels
+    kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
+    CL_CHECK(err, "clCreateKernel");
+}
+
+static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
+    if (req_size <= *cur_size) {
+        return;
+    }
+
+    // Reallocate buffer with enough space
+    if (*cur_size > 0) {
+        clReleaseMemObject(*buf);
+    }
+    cl_int err;
+    *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
+    *cur_size = req_size;
+    CL_CHECK(err, "clCreateBuffer");
+}
+
+void ggml_cl_sgemm_wrapper(
+        const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
+        const int m, const int n, const int k,
+        const float alpha, const void *host_a, const int lda,
+        const float *host_b, const int ldb, const float beta,
+        float *host_c, const int ldc, const int btype) {
+    cl_int err = 0;
+
+    cl_kernel kernel;
+    size_t global = n * k, local, size_qb;
+    bool dequant;
+    cl_block_q5_0* cl_host_b;
+
+    switch (btype) {
+    case GGML_TYPE_F32:
+        dequant = false;
+        break;
+    case GGML_TYPE_Q4_0:
+        dequant = true;
+        kernel = kernel_q4_0;
+        local = 16;
+        size_qb = global * (sizeof(float) + local) / 32;
+        break;
+    case GGML_TYPE_Q4_1:
+        dequant = true;
+        kernel = kernel_q4_1;
+        local = 16;
+        size_qb = global * (sizeof(float) * 2 + local) / 32;
+        break;
+    case GGML_TYPE_Q4_2:
+        dequant = true;
+        kernel = kernel_q4_2;
+        local = 8;
+        size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
+        break;
+    case GGML_TYPE_Q5_0:
+        dequant = true;
+        kernel = kernel_q5_0;
+        local = 16;
+        // For some reason OpenCL seems to be incapable of working with structs of size 22.
+        // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
+        // TODO Find the reason, fix and remove workaround.
+        const block_q5_0* b = (const block_q5_0*) host_b;
+        cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
+        for (size_t i = 0; i < global / 32; i++) {
+            cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
+            memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
+            memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
+        }
+        host_b = (const float*) cl_host_b;
+        size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
+        break;
+    case GGML_TYPE_Q5_1:
+        dequant = true;
+        kernel = kernel_q5_1;
+        local = 16;
+        size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32;
+        break;
+    case GGML_TYPE_Q8_0:
+        dequant = true;
+        kernel = kernel_q8_0;
+        local = 32;
+        size_qb = global * (sizeof(float) + local) / 32;
+        break;
+    default:
+        fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
+        abort();
+    }
+
+    const size_t size_a =  m * k * sizeof(float);
+    const size_t size_b =  n * k * sizeof(float);
+    const size_t size_c =  m * n * sizeof(float);
+
+    // Prepare buffers
+    ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
+    if (dequant) {
+        ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
+    }
+    ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
+    ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
+
+    cl_event ev_a, ev_qb, ev_b;
+
+    if (dequant) {
+        err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
+        err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
+        CL_CHECK(err, "clSetKernelArg");
+        err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
+        CL_CHECK(err, "clEnqueueWriteBuffer qb");
+    } else {
+        err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
+        CL_CHECK(err, "clEnqueueWriteBuffer b");
+    }
+
+    err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
+    CL_CHECK(err, "clEnqueueWriteBuffer a");
+    if (dequant) {
+        err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
+        CL_CHECK(err, "clEnqueueNDRangeKernel");
+        clReleaseEvent(ev_qb);
+    }
+    clWaitForEvents(1, &ev_a);
+    clWaitForEvents(1, &ev_b);
+    clReleaseEvent(ev_a);
+    clReleaseEvent(ev_b);
+
+    cl_event ev_sgemm;
+    CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
+                                            (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
+                                            m, n, k,
+                                            alpha,
+                                            cl_buffer_a, 0, lda,
+                                            cl_buffer_b, 0, ldb,
+                                            beta,
+                                            cl_buffer_c, 0, ldc,
+                                            &queue, &ev_sgemm);
+
+    if (status != CLBlastSuccess) {
+        fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
+        abort();
+    }
+
+    cl_event ev_c;
+    clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
+
+    // Wait for completion
+    clWaitForEvents(1, &ev_c);
+    clReleaseEvent(ev_sgemm);
+    clReleaseEvent(ev_c);
+    if (btype == GGML_TYPE_Q5_0) {
+        free((void*) cl_host_b);
+    }
+}
diff --git a/ggml-opencl.h b/ggml-opencl.h
new file mode 100644 (file)
index 0000000..7bcc603
--- /dev/null
@@ -0,0 +1,24 @@
+#pragma once
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+void ggml_cl_init(void);
+
+enum ggml_blas_order {
+    GGML_BLAS_ORDER_ROW_MAJOR = 101,
+    GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
+};
+
+enum ggml_blas_op {
+    GGML_BLAS_OP_N = 111,
+    GGML_BLAS_OP_T = 112,
+    GGML_BLAS_OP_C = 113,
+};
+
+void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml.c b/ggml.c
index 8cc48344ea4bc7fac832965b3f8886015c84465a..91b3053dd23fe8be057e1adedb0ce79487afb29f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -135,14 +135,6 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-#define GGML_ASSERT(x) \
-    do { \
-        if (!(x)) { \
-            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            abort(); \
-        } \
-    } while (0)
-
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
@@ -370,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return GGML_FP32_TO_FP16(x);
 }
 
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
+    for (size_t i = 0; i < n; i++) {
+        y[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+}
+
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
+    size_t i = 0;
+#if defined(__F16C__)
+    for (; i + 7 < n; i += 8) {
+        __m256 x_vec = _mm256_loadu_ps(x + i);
+        __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storeu_si128((__m128i *)(y + i), y_vec);
+    }
+    for(; i + 3 < n; i += 4) {
+        __m128 x_vec = _mm_loadu_ps(x + i);
+        __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storel_epi64((__m128i *)(y + i), y_vec);
+    }
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_FP32_TO_FP16(x[i]);
+    }
+}
+
+
 //
 // timing
 //
@@ -808,6 +826,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         float max = 0.0f;
         float min = 0.0f;
 
+        vector float asrcv [8];
         vector float srcv [8];
         vector float maxv[8];
         vector float minv[8];
@@ -4325,12 +4344,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
-        // initialize cuBLAS
-        #if defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_CUBLAS)
         ggml_init_cublas();
-        #elif defined(GGML_USE_CLBLAST)
+#elif defined(GGML_USE_CLBLAST)
         ggml_cl_init();
-        #endif
+#endif
 
         is_first_call = false;
     }
@@ -4411,7 +4429,7 @@ void ggml_free(struct ggml_context * ctx) {
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
-    return ctx->objects_end->offs + ctx->objects_end->size;
+    return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
 }
 
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
@@ -4524,6 +4542,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
+        /*.name         =*/ { 0 },
         /*.pad          =*/ { 0 },
     };
 
@@ -4878,6 +4897,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
     return (float *)(tensor->data);
 }
 
+const char * ggml_get_name(const struct ggml_tensor * tensor) {
+    return tensor->name;
+}
+
+void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+    strncpy(tensor->name, name, sizeof(tensor->name));
+    tensor->name[sizeof(tensor->name) - 1] = '\0';
+}
+
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5977,6 +6005,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
     struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    ggml_set_name(b, "n_past");
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6034,6 +6063,7 @@ struct ggml_tensor * ggml_rope(
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ggml_set_name(b, "n_past, n_dims, mode");
 
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -8101,7 +8131,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -8117,12 +8147,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (
-#if !defined(GGML_USE_CUBLAS)
-        ggml_is_contiguous(src0) &&
+    if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
-#endif
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
@@ -8130,7 +8157,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 
     return false;
 }
-
 #endif
 
 static void ggml_compute_forward_mul_mat_f32(
@@ -8146,7 +8172,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -8203,7 +8229,16 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CUBLAS)
+    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+        }
+        return;
+    }
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8217,43 +8252,13 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
-#if defined(GGML_USE_CUBLAS)
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        const int x_ne = ne01 * ne00;
-        const int y_ne = ne11 * ne10;
-        const int d_ne = ne11 * ne01;
-
-        size_t x_size, y_size, d_size;
-        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-#endif
-
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-#if !defined(GGML_USE_CUBLAS)
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-#endif
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-#if defined(GGML_USE_CUBLAS)
-                // copy data to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
-
-                // compute
-                CUBLAS_CHECK(
-                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, d_X, ne00,
-                                    d_Y, ne10,
-                            &beta,  d_D, ne01));
-
-                // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
                 // zT = y * xT
                 ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
                         ne11, ne01, ne10,
@@ -8270,12 +8275,6 @@ static void ggml_compute_forward_mul_mat_f32(
 #endif
             }
         }
-#if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
-        ggml_cuda_pool_free(d_X, x_size);
-        ggml_cuda_pool_free(d_Y, y_size);
-        ggml_cuda_pool_free(d_D, d_size);
-#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -8405,7 +8404,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CUBLAS)
+    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+        }
+        return;
+    }
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -8421,37 +8429,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-#if defined(GGML_USE_CUBLAS)
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        const int x_ne = ne01 * ne00;
-        const int y_ne = ne11 * ne10;
-        const int d_ne = ne11 * ne01;
-
-        size_t x_size, y_size, d_size;
-        ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
-#if defined(GGML_USE_CUBLAS)
-                // copy src0 while converting src1
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
-
-                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
-                ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-                {
-                    size_t id = 0;
-                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
-                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
-                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
-                        }
-                    }
-
-                    assert(id*sizeof(ggml_fp16_t) <= params->wsize);
-                }
-#else
                 float * const wdata = params->wdata;
                 {
                     size_t id = 0;
@@ -8463,28 +8442,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                     assert(id*sizeof(float) <= params->wsize);
                 }
-#endif
 
-#if defined(GGML_USE_CUBLAS)
-                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
-
-                // compute
-                CUBLAS_CHECK(
-                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, d_X, CUDA_R_16F, ne00,
-                                    d_Y, CUDA_R_16F, ne10,
-                            &beta,  d_D, CUDA_R_32F, ne01,
-                            CUBLAS_COMPUTE_32F,
-                            CUBLAS_GEMM_DEFAULT));
-
-                // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -8513,12 +8472,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             }
         }
 
-#if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
-        ggml_cuda_pool_free(d_X, x_size);
-        ggml_cuda_pool_free(d_Y, y_size);
-        ggml_cuda_pool_free(d_D, d_size);
-#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -8671,7 +8624,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CUBLAS)
+    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
+        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
+        }
+        return;
+    }
+#endif
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8685,25 +8647,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
-#if defined(GGML_USE_CUBLAS)
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        const int x_ne = ne01 * ne00;
-        const int y_ne = ne11 * ne10;
-        const int d_ne = ne11 * ne01;
-
-        size_t x_size, y_size, d_size, q_size;
-        float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-        void  * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
-
-        const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
-        GGML_ASSERT(dequantize_row_q_cuda != NULL);
-#else
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-#endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8711,14 +8656,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-#if defined(GGML_USE_CUBLAS)
-                // copy and dequantize on device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
-
-                dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
-                CUDA_CHECK(cudaGetLastError());
-                CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
                 const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
                 {
@@ -8734,24 +8672,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 const float * x = wdata;
 #endif
 
-#if defined(GGML_USE_CUBLAS)
-                // copy data to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
-
-                // wait for dequantization
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
-
-                // compute
-                CUBLAS_CHECK(
-                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, d_X, ne00,
-                                    d_Y, ne10,
-                            &beta,  d_D, ne01));
-
-                // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
                 // zT = y * xT
                 ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
                         ne11, ne01, ne10,
@@ -8769,13 +8690,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
             }
         }
 
-#if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
-        ggml_cuda_pool_free(d_X, x_size);
-        ggml_cuda_pool_free(d_Y, y_size);
-        ggml_cuda_pool_free(d_D, d_size);
-        ggml_cuda_pool_free(d_Q, q_size);
-#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -11759,18 +11673,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
+#if defined(GGML_USE_CUBLAS)
+                        if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
+                            node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                //       the threads are still spinning
+                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
+                        }
+                        else
+#endif
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
-#if defined(GGML_USE_CUBLAS)
-                                // with cuBLAS, we need memory for the full 3D / 4D data of src1
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#else
                                 // here we need memory just for single 2D matrix from src0
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-#endif
                             } else {
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
                             }
@@ -11779,13 +11696,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                             }
 #endif
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -12214,10 +12131,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
             snprintf(color, sizeof(color), "white");
         }
 
-        fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
-                (void *) node, color,
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+            fprintf(fp, "%s |", node->name);
+        }
+
+        fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
                 i, node->ne[0], node->ne[1],
                 GGML_OP_SYMBOL[node->op]);
 
@@ -12233,18 +12156,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
 
         snprintf(color, sizeof(color), "pink");
 
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"<x>",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+                fprintf(fp, "%s | ", node->name);
+        }
         if (ggml_nelements(node) == 1) {
-            fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
-        } else {
-            fprintf(fp, "  \"%p\" [ \
-style = filled; fillcolor = %s; shape = record; \
-label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
-                    (void *) node, color,
-                    i, node->ne[0], node->ne[1]);
+            if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
+                fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
+            }
+            else {
+                fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
+            }
+        }
+        else {
+            fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
         }
+        fprintf(fp, "\"; ]\n");
     }
 
     for (int i = 0; i < gb->n_nodes; i++) {
diff --git a/ggml.h b/ggml.h
index cbaea3ed7e5ad4c3df06598f2f699f067ef0a2cb..508dd69b41713209d2339bd06148fcc96ceebec1 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_OPT           4
 #define GGML_DEFAULT_N_THREADS 4
 
+#define GGML_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -212,6 +220,9 @@ extern "C" {
     GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
     GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
 
+    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
+    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
+
     struct ggml_object;
     struct ggml_context;
 
@@ -339,7 +350,10 @@ extern "C" {
         int64_t perf_time_us;
 
         void * data;
-        char padding[8];
+
+        char name[32];
+
+        char padding[8]; // TODO: remove and add padding to name?
     };
 
     // computation graph
@@ -399,6 +413,7 @@ extern "C" {
 
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
 
+    // TODO: temporary until model loading of ggml examples is refactored
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
     // main
@@ -461,6 +476,9 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
+    GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
+    GGML_API void         ggml_set_name(struct ggml_tensor * tensor, const char * name);
+
     //
     // operations on tensors with backpropagation
     //