]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: mmq CLI option, fixed mmq build issues (#2453)
authorJohannes Gäßler <redacted>
Mon, 31 Jul 2023 13:44:35 +0000 (15:44 +0200)
committerGitHub <redacted>
Mon, 31 Jul 2023 13:44:35 +0000 (15:44 +0200)
CMakeLists.txt
Makefile
README.md
examples/common.cpp
examples/common.h
examples/server/server.cpp
ggml-cuda.cu
ggml-cuda.h
llama.cpp
llama.h

index 57678a302d0977203a62f1637081017240b15f01..4ecb3d5862a0d30a3234a10c1c7d84ada6589dd7 100644 (file)
@@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE                      "llama: enable Accelerate framework
 option(LLAMA_BLAS                            "llama: use BLAS"                                  OFF)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use CUDA"                                  OFF)
-option(LLAMA_CUDA_CUBLAS                     "llama: use cuBLAS for prompt processing"          OFF)
+#option(LLAMA_CUDA_CUBLAS                     "llama: use cuBLAS for prompt processing"          OFF)
 set(LLAMA_CUDA_MMQ_Y       "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"     OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
@@ -253,9 +253,9 @@ if (LLAMA_CUBLAS)
         set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
 
         add_compile_definitions(GGML_USE_CUBLAS)
-        if (LLAMA_CUDA_CUBLAS)
-            add_compile_definitions(GGML_CUDA_CUBLAS)
-        endif()
+#        if (LLAMA_CUDA_CUBLAS)
+#            add_compile_definitions(GGML_CUDA_CUBLAS)
+#        endif()
         add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
         if (LLAMA_CUDA_FORCE_DMMV)
             add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -277,10 +277,14 @@ if (LLAMA_CUBLAS)
         endif()
 
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # 52 == lowest CUDA 12 standard
+        # 60 == f16 CUDA intrinsics
+        # 61 == integer CUDA intrinsics
+        # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
         if (LLAMA_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
         else()
-            set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
index 616c2d9b88be9d82000f0287fba43fc246a7afe2..ebeadfdd05fd5ed587000d8b546eb3e0a2e89e1d 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -236,9 +236,9 @@ ifdef LLAMA_CUDA_MMQ_Y
 else
        NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
 endif # LLAMA_CUDA_MMQ_Y
-ifdef LLAMA_CUDA_CUBLAS
-       NVCCFLAGS += -DGGML_CUDA_CUBLAS
-endif # LLAMA_CUDA_CUBLAS
+#ifdef LLAMA_CUDA_CUBLAS
+#      NVCCFLAGS += -DGGML_CUDA_CUBLAS
+#endif # LLAMA_CUDA_CUBLAS
 ifdef LLAMA_CUDA_CCBIN
        NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
index 42fc42b05ae564954edb3b026ad46b3292c084e1..b231d24b8185944deee075b33cb72a2bd0933d1b 100644 (file)
--- a/README.md
+++ b/README.md
@@ -400,9 +400,11 @@ Building the program with BLAS support may lead to some performance improvements
 
   The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
 
+<!---
+  | LLAMA_CUDA_CUBLAS       | Boolean                |   false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
+--->
   | Option                  | Legal values           | Default | Description |
   |-------------------------|------------------------|---------|-------------|
-  | LLAMA_CUDA_CUBLAS       | Boolean                |   false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
   | LLAMA_CUDA_MMQ_Y        | Positive integer >= 32 |      64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
   | LLAMA_CUDA_FORCE_DMMV   | Boolean                |   false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 |      32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
index fe7308b1787eb9377fea2818902f1180c7392f1a..e6439841dc1455775282df378c11050d9e06890c 100644 (file)
@@ -352,7 +352,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 #ifdef GGML_USE_CUBLAS
             params.main_gpu = std::stoi(argv[i]);
 #else
-      fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
 #endif
         } else if (arg == "--tensor-split" || arg == "-ts") {
             if (++i >= argc) {
@@ -376,13 +376,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 }
             }
 #else
-      fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
+        } else if (arg == "--mul-mat-q" || arg == "-mmq") {
+#ifdef GGML_USE_CUBLAS
+            params.mul_mat_q = true;
+#else
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--low-vram" || arg == "-lv") {
 #ifdef GGML_USE_CUBLAS
             params.low_vram = true;
 #else
-      fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
@@ -585,6 +591,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n" );
     fprintf(stdout, "  -lv, --low-vram       don't allocate VRAM scratch buffer\n" );
+    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
+    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
+    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
 #endif
     fprintf(stdout, "  --mtest               compute maximum memory usage\n");
     fprintf(stdout, "  --export              export the computation graph to 'llama.ggml'\n");
@@ -637,6 +646,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.main_gpu        = params.main_gpu;
     lparams.tensor_split    = params.tensor_split;
     lparams.low_vram        = params.low_vram;
+    lparams.mul_mat_q       = params.mul_mat_q;
     lparams.seed            = params.seed;
     lparams.f16_kv          = params.memory_f16;
     lparams.use_mmap        = params.use_mmap;
index 1184f32df50742040d38d5e7164d33ff8d198c26..974484207251dce531df3228c5d9ae87493d8f5f 100644 (file)
@@ -74,6 +74,7 @@ struct gpt_params {
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 
     bool low_vram          = false; // if true, reduce VRAM usage at the cost of performance
+    bool mul_mat_q         = false; // if true, use experimental mul_mat_q kernels
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
index 83c03065a5d583daf150026f796532c165a8e224..c0725088f10188b3f7969a820e21648362471cee 100644 (file)
@@ -631,6 +631,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
     fprintf(stdout, "  -lv, --low-vram don't allocate VRAM scratch buffer\n");
+    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
+    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
+    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
 #endif
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
@@ -827,7 +830,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 }
             }
 #else
-            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.", {});
+            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
 #endif // GGML_USE_CUBLAS
         }
         else if (arg == "--low-vram" || arg == "-lv")
@@ -835,7 +838,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 #ifdef GGML_USE_CUBLAS
             params.low_vram = true;
 #else
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
+            LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
+#endif // GGML_USE_CUBLAS
+        }
+        else if (arg == "--mul-mat-q" || arg == "-mmq")
+        {
+#ifdef GGML_USE_CUBLAS
+            params.mul_mat_q = true;
+#else
+            LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n", {});
 #endif // GGML_USE_CUBLAS
         }
         else if (arg == "--main-gpu" || arg == "-mg")
index bcdff3640c13dfb9e4f95000ba97a989ecf93254..f11fbe57c11123736bc07ccb2d9dd16092f3b55f 100644 (file)
@@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0;
 
 static int g_device_count = -1;
 static int g_main_device = 0;
-#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
-#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+static bool g_mul_mat_q = false;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
@@ -3923,9 +3922,7 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 
-#ifndef GGML_CUDA_FORCE_DMMV
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
-#endif
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec(
 
 #ifdef GGML_CUDA_FORCE_DMMV
     const bool use_mul_mat_vec_q = false;
+    (void) g_compute_capabilities[0];
 #else
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
@@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
         } else {
-#ifdef GGML_CUDA_CUBLAS
-            const bool use_mul_mat_q = false;
-#else
-            const bool use_mul_mat_q = ggml_is_quantized(src0->type);
-#endif // GGML_CUDA_CUBLAS
-            if (use_mul_mat_q) {
+            int min_compute_capability = INT_MAX;
+            for (int id = 0; id < g_device_count; ++id) {
+                if (min_compute_capability > g_compute_capabilities[id]) {
+                    min_compute_capability = g_compute_capabilities[id];
+                }
+            }
+
+            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
             } else {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
@@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }
 
+void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+    g_mul_mat_q = mul_mat_q;
+}
+
 void ggml_cuda_set_scratch_size(size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
index 3c1e8deb6a6ddbfbb43e22bd01cc64e560fe313d..72d7afa463d741498af0063f0ccf14e5b9028bf9 100644 (file)
@@ -27,6 +27,7 @@ void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
+void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
 bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
index 50da4274fa2984b7251c21412e224b47b65533dd..d427054dd8d6ae585490fdc63ddb5cd2299c68fa 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -901,6 +901,7 @@ struct llama_context_params llama_context_default_params() {
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
+        /*.mul_mat_q                   =*/ false,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
         /*.vocab_only                  =*/ false,
@@ -1028,6 +1029,7 @@ static void llama_model_load_internal(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
@@ -1156,9 +1158,11 @@ static void llama_model_load_internal(
     }
 
     (void) main_gpu;
+    (void) mul_mat_q;
 #if defined(GGML_USE_CUBLAS)
     fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
     ggml_cuda_set_main_device(main_gpu);
+    ggml_cuda_set_mul_mat_q(mul_mat_q);
 #define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
 #elif defined(GGML_USE_CLBLAST)
@@ -1367,6 +1371,7 @@ static bool llama_model_load(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
@@ -1377,7 +1382,8 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers,
+                                  main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -3192,7 +3198,7 @@ struct llama_model * llama_load_model_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
+                params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
         delete model;
diff --git a/llama.h b/llama.h
index df46f9b9cfa311ea78906e5d4eb1f7c8a651e055..fa1977f2d949241834f3936424174c2133ef3ac0 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -108,6 +108,7 @@ extern "C" {
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool low_vram;   // if true, reduce VRAM usage at the cost of performance
+        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights