message(WARNING "Accelerate framework not found")
endif()
endif()
+
if (LLAMA_OPENBLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
+ enable_language(CUDA)
+
+ set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+
add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_STATIC)
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
- add_compile_options(/arch:AVX512)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
- add_compile_definitions(__AVX512VBMI__)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
- add_compile_definitions(__AVX512VNNI__)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
- add_compile_options(/arch:AVX2)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
- add_compile_options(/arch:AVX)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
add_library(ggml OBJECT
ggml.c
- ggml.h)
+ ggml.h
+ ${GGML_CUDA_SOURCES})
target_include_directories(ggml PUBLIC .)
target_compile_features(ggml PUBLIC c_std_11) # don't bump
target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()
+if (GGML_CUDA_SOURCES)
+ message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
+ set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
+ set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+ set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
+endif()
+
+
#
# programs, examples and tests
#
+# Define the default target now so that it is always the first target
+default: main quantize quantize-stats perplexity embedding vdot
+
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
+ OBJS += ggml-cuda.o
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+ nvcc -arch=native -c -o $@ $<
endif
ifdef LLAMA_GPROF
CFLAGS += -pg
$(info I CXX: $(CXXV))
$(info )
-default: main quantize quantize-stats perplexity embedding vdot
-
#
# Build library
#
clean:
rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
-main: examples/main/main.cpp ggml.o llama.o common.o
+main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo
-quantize: examples/quantize/quantize.cpp ggml.o llama.o
+quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o
+quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
+perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
+embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-vdot: pocs/vdot/vdot.cpp ggml.o
+vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-libllama.so: llama.o ggml.o
+libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
#
# Tests
#
-benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o
+benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS)
./benchmark-q4_0-matmult
--- /dev/null
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include "ggml-cuda.h"
+
+typedef uint16_t ggml_fp16_t;
+static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define QK4_0 32
+typedef struct {
+ float d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+ float d; // delta
+ float m; // min
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+ __half d; // delta
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+
+static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const int i = blockIdx.x;
+
+ const float d = x[i].d;
+
+ const uint8_t * pp = x[i].qs;
+
+ for (int l = 0; l < QK4_0; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ const int8_t vi0 = vi & 0xf;
+ const int8_t vi1 = vi >> 4;
+
+ const float v0 = (vi0 - 8)*d;
+ const float v1 = (vi1 - 8)*d;
+
+ y[i*QK4_0 + l + 0] = v0;
+ y[i*QK4_0 + l + 1] = v1;
+ }
+}
+
+static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const int i = blockIdx.x;
+
+ const float d = x[i].d;
+ const float m = x[i].m;
+
+ const uint8_t * pp = x[i].qs;
+
+ for (int l = 0; l < QK4_1; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ const int8_t vi0 = vi & 0xf;
+ const int8_t vi1 = vi >> 4;
+
+ const float v0 = vi0*d + m;
+ const float v1 = vi1*d + m;
+
+ y[i*QK4_1 + l + 0] = v0;
+ y[i*QK4_1 + l + 1] = v1;
+ }
+}
+
+static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+ const block_q4_2 * x = (const block_q4_2 *) vx;
+
+ const int i = blockIdx.x;
+
+ const float d = x[i].d;
+
+ const uint8_t * pp = x[i].qs;
+
+ for (int l = 0; l < QK4_2; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ const int8_t vi0 = vi & 0xf;
+ const int8_t vi1 = vi >> 4;
+
+ const float v0 = (vi0 - 8)*d;
+ const float v1 = (vi1 - 8)*d;
+
+ y[i*QK4_2 + l + 0] = v0;
+ y[i*QK4_2 + l + 1] = v1;
+ }
+}
+
+extern "C" {
+ __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+ const int nb = k / QK4_0;
+ dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+ }
+
+ __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+ const int nb = k / QK4_1;
+ dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+ }
+
+ __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+ const int nb = k / QK4_2;
+ dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+ }
+}
--- /dev/null
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+
+#ifdef __cplusplus
+}
+#endif
#elif defined(GGML_USE_CUBLAS)
#include <cublas_v2.h>
#include <cuda_runtime.h>
-#define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- exit(1); \
- } \
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err) \
+ do { \
+ cudaError_t err_ = (err); \
+ if (err_ != cudaSuccess) { \
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+ cudaGetErrorString(err_)); \
+ exit(1); \
+ } \
} while (0)
-#define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- exit(1); \
- } \
+#define CUBLAS_CHECK(err) \
+ do { \
+ cublasStatus_t err_ = (err); \
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+ exit(1); \
+ } \
} while (0)
static cublasHandle_t cublasH = NULL;
CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
// configure logging to stdout
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
}
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
return;
}
- float * const wdata = params->wdata;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
+ float *d_Q = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
+ if (type == GGML_TYPE_Q4_0) {
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+ }
+ else if (type == GGML_TYPE_Q4_1) {
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+ }
+ else if (type == GGML_TYPE_Q4_2) {
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+#else
+ float * const wdata = params->wdata;
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+ // copy and dequantize on device
+ CUDA_CHECK(
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+ CUDA_CHECK(cudaGetLastError());
+#else
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
id += ne00;
}
}
-
const float * x = wdata;
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+#endif
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
}
#if defined(GGML_USE_CUBLAS)
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
+ CUDA_CHECK(cudaFree(d_Q));
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);