]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add Q4_3 support to cuBLAS (#1086)
authorslaren <redacted>
Thu, 20 Apr 2023 18:49:53 +0000 (20:49 +0200)
committerGitHub <redacted>
Thu, 20 Apr 2023 18:49:53 +0000 (20:49 +0200)
Makefile
ggml-cuda.cu
ggml-cuda.h

index 8483d66ce4a7d46eecac1eb00f614572c01f2b55..f267d086415eec490bd052f5f1d3876585c8bba0 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -102,7 +102,7 @@ ifdef LLAMA_OPENBLAS
 endif
 ifdef LLAMA_CUBLAS
        CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
-       LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -lrt -L/usr/local/cuda/lib64
+       LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
        OBJS    += ggml-cuda.o
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        nvcc -arch=native -c -o $@ $<
index 7cd116602b9b093a8c7b327c5e8c575f0500d340..0baa989a36ca917eac3cea2c021255a7e68f5911 100644 (file)
@@ -22,11 +22,20 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b
 
 #define QK4_2 16
 typedef struct {
-    __half d;               // delta
+    __half  d;              // delta
     uint8_t qs[QK4_2 / 2];  // nibbles / quants
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
+#define QK4_3 16
+typedef struct {
+    __half  d;         // delta
+    __half  m;         // min
+    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
+
 
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -98,6 +107,30 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
     }
 }
 
+static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
+    const block_q4_3 * x = (const block_q4_3 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_3; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_3 + l + 0] = v0;
+        y[i*QK4_3 + l + 1] = v1;
+    }
+}
+
 extern "C" {
     __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
         const int nb = k / QK4_0;
@@ -113,4 +146,9 @@ extern "C" {
         const int nb = k / QK4_2;
         dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
     }
+
+    __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+        const int nb = k / QK4_3;
+        dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+    }
 }
index 646caafc6f8a0413d48f8917a6712ac42d9bec4c..be140606aa2d456514edde044b66796b74d92abf 100644 (file)
@@ -5,6 +5,7 @@ extern "C" {
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 
 #ifdef  __cplusplus
 }