]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : per-layer KV cache + quantum K cache (#4309)
authorGeorgi Gerganov <redacted>
Thu, 7 Dec 2023 11:03:17 +0000 (13:03 +0200)
committerGitHub <redacted>
Thu, 7 Dec 2023 11:03:17 +0000 (13:03 +0200)
* per-layer KV

* remove unnecessary copies

* less code duplication, offload k and v separately

* llama : offload KV cache per-layer

* llama : offload K shift tensors

* llama : offload for rest of the model arches

* llama : enable offload debug temporarily

* llama : keep the KV related layers on the device

* llama : remove mirrors, perform Device -> Host when partial offload

* common : add command-line arg to disable KV cache offloading

* llama : update session save/load

* llama : support quantum K cache (#4312)

* llama : support quantum K cache (wip)

* metal : add F32 -> Q8_0 copy kernel

* cuda : add F32 -> Q8_0 copy kernel

ggml-ci

* cuda : use mmv kernel for quantum cache ops

* llama : pass KV cache type through API

* llama : fix build

ggml-ci

* metal : add F32 -> Q4_0 copy kernel

* metal : add F32 -> Q4_1 copy kernel

* cuda : wip

* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels

* llama-bench : support type_k/type_v

* metal : use mm kernel only for quantum KV cache

* cuda : add comment

* llama : remove memory_f16 and kv_f16 flags

---------

Co-authored-by: slaren <redacted>
* readme : add API change notice

---------

Co-authored-by: slaren <redacted>
README.md
common/common.cpp
common/common.h
examples/llama-bench/llama-bench.cpp
examples/quantize-stats/quantize-stats.cpp
examples/server/server.cpp
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
llama.cpp
llama.h

index dac971ae5dfe56171d487a99f371e51fd3d4cd1e..ce026b8d1d851cf16a769b332efd412a52687206 100644 (file)
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
+- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
 - Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
 - Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
 - Collecting Apple Silicon performance stats: https://github.com/ggerganov/llama.cpp/discussions/4167
index 4e823c526e2e6b799dde84327eecf64e6978b349..4a61ae5937f6432b5840b1926225294be67c0443 100644 (file)
@@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
-        } else if (arg == "--memory-f32") {
-            params.memory_f16 = false;
         } else if (arg == "--samplers") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.infill = true;
         } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
             params.dump_kv_cache = true;
+        } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+            params.no_kv_offload = true;
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
+            params.cache_type_k = argv[++i];
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            params.cache_type_v = argv[++i];
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
-    printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
-    printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)sparams.temp);
     printf("  --logits-all          return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                        verbose print of the KV cache\n");
+    printf("  -nkvo, --no-kv-offload\n");
+    printf("                        disable KV offload\n");
+    printf("  -ctk TYPE, --cache-type-k TYPE\n");
+    printf("                        KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+    printf("  -ctv TYPE, --cache-type-v TYPE\n");
+    printf("                        KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
     printf("  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     return mparams;
 }
 
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    throw std::runtime_error("Invalid cache type: " + s);
+}
+
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();
 
@@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q         = params.mul_mat_q;
     cparams.seed              = params.seed;
-    cparams.f16_kv            = params.memory_f16;
     cparams.logits_all        = params.logits_all;
     cparams.embedding         = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_fast    = params.yarn_beta_fast;
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.offload_kqv       = !params.no_kv_offload;
+
+    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 
     return cparams;
 }
@@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
-    fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
index 02467938061b2e7d213a413b34ee4d9bf460271f..e87ce113398b3edfb1f549c12f63df6736e3ca14 100644 (file)
@@ -100,7 +100,6 @@ struct gpt_params {
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
-    bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
@@ -125,10 +124,14 @@ struct gpt_params {
     bool verbose_prompt    = false; // print prompt tokens before generation
     bool infill            = false; // use infill mode
     bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
+    bool no_kv_offload     = false; // disable KV offloading
+
+    std::string cache_type_k = "f16"; // KV cache data type for the K
+    std::string cache_type_v = "f16"; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
-    std::string image = ""; // path to an image file
+    std::string image  = ""; // path to an image file
 };
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
index 9bd82d565834a9dc788725b6ac927bde896a79b5..6617c050ddfec7fb5e1baab713578316bf6f62ac 100644 (file)
@@ -53,6 +53,13 @@ static std::vector<T> split(const std::string & str, char delim) {
     return values;
 }
 
+template<typename T, typename F>
+static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
+    std::vector<std::string> str_values;
+    std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
+    return str_values;
+}
+
 template<typename T>
 static T avg(const std::vector<T> & v) {
     if (v.empty()) {
@@ -126,7 +133,8 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
-    std::vector<bool> f32_kv;
+    std::vector<ggml_type> type_k;
+    std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
     std::vector<int> n_gpu_layers;
     std::vector<int> main_gpu;
@@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = {
     /* n_prompt      */ {512},
     /* n_gen         */ {128},
     /* n_batch       */ {512},
-    /* f32_kv        */ {false},
+    /* type_k        */ {GGML_TYPE_F16},
+    /* type_v        */ {GGML_TYPE_F16},
     /* n_threads     */ {get_num_physical_cores()},
     /* n_gpu_layers  */ {99},
     /* main_gpu      */ {0},
@@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -p, --n-prompt <n>                (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen <n>                   (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -b, --batch-size <n>              (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
-    printf("  --memory-f32 <0|1>                (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
+    printf("  -ctk <t>, --cache-type-k <t>      (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
+    printf("  -ctv <t>, --cache-type-v <t>      (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n>                 (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -ngl, --n-gpu-layers <n>          (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
     printf("  -mg, --main-gpu <i>               (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
@@ -173,9 +183,32 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -v, --verbose                     (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
+}
 
+static ggml_type ggml_type_from_name(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    return GGML_TYPE_COUNT;
 }
 
+
 static cmd_params parse_cmd_params(int argc, char ** argv) {
     cmd_params params;
     std::string arg;
@@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
-        } else if (arg == "--memory-f32") {
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
-            params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_k.insert(params.type_k.end(), types.begin(), types.end());
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_v.insert(params.type_v.end(), types.begin(), types.end());
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty())     { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())        { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())      { params.n_batch = cmd_params_defaults.n_batch; }
-    if (params.f32_kv.empty())       { params.f32_kv = cmd_params_defaults.f32_kv; }
+    if (params.type_k.empty())       { params.type_k = cmd_params_defaults.type_k; }
+    if (params.type_v.empty())       { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.mul_mat_q.empty())    { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
@@ -336,7 +395,8 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_threads;
     int n_gpu_layers;
     int main_gpu;
@@ -365,7 +425,8 @@ struct cmd_params_instance {
 
         cparams.n_ctx = n_prompt + n_gen;
         cparams.n_batch = n_batch;
-        cparams.f16_kv = !f32_kv;
+        cparams.type_k = type_k;
+        cparams.type_v = type_v;
         cparams.mul_mat_q = mul_mat_q;
 
         return cparams;
@@ -380,7 +441,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
@@ -388,7 +450,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
             /* .n_prompt     = */ n_prompt,
             /* .n_gen        = */ n_gen,
             /* .n_batch      = */ nb,
-            /* .f32_kv       = */ fk,
+            /* .type_k       = */ tk,
+            /* .type_v       = */ tv,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu     = */ mg,
@@ -410,7 +473,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
@@ -422,7 +486,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ n_prompt,
                 /* .n_gen        = */ 0,
                 /* .n_batch      = */ nb,
-                /* .f32_kv       = */ fk,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
                 /* .n_gpu_layers = */ nl,
                 /* .main_gpu     = */ mg,
@@ -441,7 +506,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ 0,
                 /* .n_gen        = */ n_gen,
                 /* .n_batch      = */ nb,
-                /* .f32_kv       = */ fk,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
                 /* .n_gpu_layers = */ nl,
                 /* .main_gpu     = */ mg,
@@ -489,7 +555,8 @@ struct test {
     uint64_t model_n_params;
     int n_batch;
     int n_threads;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
@@ -508,7 +575,8 @@ struct test {
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
         n_threads = inst.n_threads;
-        f32_kv = inst.f32_kv;
+        type_k = inst.type_k;
+        type_v = inst.type_v;
         n_gpu_layers = inst.n_gpu_layers;
         main_gpu = inst.main_gpu;
         mul_mat_q = inst.mul_mat_q;
@@ -571,7 +639,7 @@ struct test {
             "cuda", "opencl", "metal", "gpu_blas", "blas",
             "cpu_info", "gpu_info",
             "model_filename", "model_type", "model_size", "model_n_params",
-            "n_batch", "n_threads", "f16_kv",
+            "n_batch", "n_threads", "type_k", "type_v",
             "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
@@ -621,7 +689,7 @@ struct test {
             std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-            std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
+            std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -805,8 +873,11 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.push_back("n_batch");
         }
-        if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
-            fields.push_back("f16_kv");
+        if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
+            fields.push_back("type_k");
+        }
+        if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
+            fields.push_back("type_v");
         }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.push_back("main_gpu");
index 2712824774ae74fba24ee2ffd06bbfaa81ad919a..773024160f839b517953e208ffb4ac66b226213e 100644 (file)
@@ -321,7 +321,6 @@ int main(int argc, char ** argv) {
         auto cparams = llama_context_default_params();
         cparams.n_ctx      = 256;
         cparams.seed       = 1;
-        cparams.f16_kv     = false;
 
         ctx = llama_new_context_with_model(model, cparams);
 
index 369f81a8428b2c44ed00a552fbcaabd9bfdf5db8..895f751c9fdbbe8e613f05c028d892d50c300b43 100644 (file)
@@ -2108,10 +2108,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.yarn_beta_slow = std::stof(argv[i]);
         }
-        else if (arg == "--memory-f32" || arg == "--memory_f32")
-        {
-            params.memory_f16 = false;
-        }
         else if (arg == "--threads" || arg == "-t")
         {
             if (++i >= argc)
index 9019a849f0bff14945251c63145e52f3d05d6455..1200d1c888b42f086a63a8863a10b576de86dd44 100644 (file)
@@ -7,6 +7,7 @@
 #include <stdio.h>
 #include <atomic>
 #include <assert.h>
+#include <float.h>
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -4559,6 +4560,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -5737,6 +5848,39 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -6093,20 +6237,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const enum ggml_type type = src->type;
     const int64_t ts = ggml_type_size(type);
     const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    const int64_t i1_diff = i1_high - i1_low;
 
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
+    if (nb0 == ts && nb1 == ts*(ne0/bs)) {
         return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
     }
     if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+        return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream);
     }
+    GGML_ASSERT(bs == 1 && "TODO: implement bs != 1");
     for (int64_t i1 = 0; i1 < i1_diff; i1++) {
         const void * rx = (const void *) ((const char *) x + i1*nb1);
-        void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+        void * rd = (void *) (dst_ptr + i1*ts*ne0);
         // pretend the row is a matrix with cols=1
-        cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+        cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream);
         if (r != cudaSuccess) { return r; }
     }
     return cudaSuccess;
@@ -6474,6 +6619,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
@@ -6533,7 +6680,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     size_t ash;
     dfloat * src1_dfloat = nullptr; // dfloat == half
 
-    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
         src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
@@ -7103,10 +7251,9 @@ static void ggml_cuda_op_mul_mat(
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
-
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
-    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
     GGML_ASSERT(!(split && ne02 > 1));
@@ -7231,7 +7378,7 @@ static void ggml_cuda_op_mul_mat(
                 const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
                 float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -7698,10 +7845,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
 #endif // GGML_CUDA_FORCE_DMMV
 
             if (use_mul_mat_vec_q) {
+                // NOTE: this kernel does not support ggml_nrows(src1) > 1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
@@ -7770,14 +7918,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7788,6 +7939,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 }
 
 static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // TODO: why do we pass dst as src1 here?
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }
index 3343bc8a3af37f200ffb3a493831b09054b75be4..be4ab0f2ed47c87d4b7d6043cee3787ef441b604 100644 (file)
@@ -118,6 +118,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
@@ -324,6 +329,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
@@ -425,6 +435,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
@@ -1114,7 +1129,7 @@ void ggml_metal_graph_compute(
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 ne00 % 32 == 0 && ne00 >= 64 &&
-                                ne11 > ne11_mm_min) {
+                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
@@ -1549,14 +1564,23 @@ void ggml_metal_graph_compute(
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                         {
-                            const int nth = MIN(1024, ne00);
+                            GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                            int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 
                             switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
+                                        GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
                                         switch (dstt) {
-                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
-                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16];  break;
+                                            case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];  break;
+                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+                                            //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+                                            //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
index 9a79f815f3a72088665c711405966c9be26cce21..9f5ffcbafe8fc07ae3fcf0485bd6b89c96ad1efa 100644 (file)
@@ -3,6 +3,7 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
 
 #define QK4_0 32
 #define QR4_0 2
@@ -1460,6 +1461,197 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_cpy_f32_q8_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
 kernel void kernel_concat(
     device const char * src0,
     device const char * src1,
index 14e5d312e6ffc08fe4f8f79cbcb898166c9f18ce..b12bbd1b054422a4ef4a5492b2a39eca357a6a41 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1231,6 +1231,7 @@ struct llama_cparams {
     float yarn_beta_slow;
 
     bool mul_mat_q;
+    bool offload_kqv;
 };
 
 struct llama_layer {
@@ -1299,8 +1300,8 @@ struct llama_kv_cache {
 
     std::vector<llama_kv_cell> cells;
 
-    struct ggml_tensor * k = NULL;
-    struct ggml_tensor * v = NULL;
+    std::vector<struct ggml_tensor *> k_l; // per layer
+    std::vector<struct ggml_tensor *> v_l;
 
     struct ggml_context * ctx = NULL;
 
@@ -1313,8 +1314,10 @@ struct llama_kv_cache {
 
 #ifdef GGML_USE_CUBLAS
         if (ggml_cublas_loaded()) {
-            ggml_cuda_free_data(k);
-            ggml_cuda_free_data(v);
+            for (size_t i = 0; i < k_l.size(); ++i) {
+                ggml_cuda_free_data(k_l[i]);
+                ggml_cuda_free_data(v_l[i]);
+            }
         }
 #endif
     }
@@ -1504,9 +1507,11 @@ struct llama_context {
 static bool llama_kv_cache_init(
         const struct llama_hparams & hparams,
              struct llama_kv_cache & cache,
-                         ggml_type   wtype,
+                         ggml_type   ktype,
+                         ggml_type   vtype,
                           uint32_t   n_ctx,
-                               int   n_gpu_layers) {
+                               int   n_gpu_layers,
+                              bool   offload) {
     const uint32_t n_embd  = hparams.n_embd_gqa();
     const uint32_t n_layer = hparams.n_layer;
 
@@ -1522,7 +1527,7 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
+    cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
     memset(cache.buf.data, 0, cache.buf.size);
 
     struct ggml_init_params params;
@@ -1532,37 +1537,44 @@ static bool llama_kv_cache_init(
 
     cache.ctx = ggml_init(params);
 
+    size_t vram_kv_cache = 0;
+
     if (!cache.ctx) {
         LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    ggml_set_name(cache.k, "cache_k");
-    ggml_set_name(cache.v, "cache_v");
+    cache.k_l.reserve(n_layer);
+    cache.v_l.reserve(n_layer);
 
-    (void) n_gpu_layers;
+    const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
 
-#ifdef GGML_USE_CUBLAS
-    if (ggml_cublas_loaded()) {
-        size_t vram_kv_cache = 0;
+    GGML_UNUSED(offload);
 
-        if (n_gpu_layers > (int)n_layer + 1) {
-            ggml_cuda_assign_buffers_no_scratch(cache.v);
-            LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
-            vram_kv_cache += ggml_nbytes(cache.v);
-        }
-        if (n_gpu_layers > (int)n_layer + 2) {
-            ggml_cuda_assign_buffers_no_scratch(cache.k);
-            LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
-            vram_kv_cache += ggml_nbytes(cache.k);
-        }
-        if (vram_kv_cache > 0) {
-            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+    for (int i = 0; i < (int) n_layer; i++) {
+        ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
+        ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
+        ggml_format_name(k, "cache_k_l%d", i);
+        ggml_format_name(v, "cache_v_l%d", i);
+        cache.k_l.push_back(k);
+        cache.v_l.push_back(v);
+#ifdef GGML_USE_CUBLAS
+        if (i >= i_gpu_start) {
+            if (offload) {
+                ggml_cuda_assign_buffers_no_scratch(k);
+                vram_kv_cache += ggml_nbytes(k);
+                ggml_cuda_assign_buffers_no_scratch(v);
+                vram_kv_cache += ggml_nbytes(v);
+            }
         }
+#endif // GGML_USE_CUBLAS
     }
-#endif
+
+    if (vram_kv_cache > 0) {
+        LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+    }
+
+    GGML_UNUSED(n_gpu_layers);
 
     return true;
 }
@@ -2968,14 +2980,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3045,14 +3050,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3115,14 +3113,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3192,14 +3183,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3269,21 +3253,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-#ifdef GGML_USE_CUBLAS
-                            if (n_gpu_layers > int(n_layer + 1)) {
-                                LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
-                                    __func__, n_layer + 1);
-                                throw std::runtime_error("Persimmon CUDA offload failed");
-                            }
-#endif
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3342,14 +3312,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3420,14 +3383,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3487,14 +3443,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3559,14 +3508,7 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
-                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                            // on Windows however this is detrimental unless everything is on the GPU
-#ifndef _WIN32
-                            backend_norm = llama_backend_offload;
-#else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
-#endif // _WIN32
-
+                            backend_norm   = llama_backend_offload;
                             backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
@@ -3642,8 +3584,8 @@ static void llm_load_tensors(
         }
 
 #ifdef GGML_USE_CUBLAS
-        const int max_backend_supported_layers = hparams.n_layer + 3;
-        const int max_offloadable_layers       = hparams.n_layer + 3;
+        const int max_backend_supported_layers = hparams.n_layer + 1;
+        const int max_offloadable_layers       = hparams.n_layer + 1;
 #elif GGML_USE_CLBLAST
         const int max_backend_supported_layers = hparams.n_layer + 1;
         const int max_offloadable_layers       = hparams.n_layer + 1;
@@ -3811,11 +3753,11 @@ static void llm_build_k_shift(
         struct ggml_tensor * tmp =
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx,
-                    ggml_view_3d(ctx, kv.k,
+                    ggml_view_3d(ctx, kv.k_l[il],
                         n_embd_head, n_head_kv, n_ctx,
-                        ggml_element_size(kv.k)*n_embd_head,
-                        ggml_element_size(kv.k)*n_embd_gqa,
-                        ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
+                        ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
+                        ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
+                        0),
                     K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
         cb(tmp, "K_shifted", il);
@@ -3842,13 +3784,13 @@ static void llm_build_kv_store(
     //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
     cb(v_cur_t, "v_cur_t", il);
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
-            (ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
+            (ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
-    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
-            (   n_ctx)*ggml_element_size(kv.v),
-            (il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
+    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
+            (  n_ctx)*ggml_element_size(kv.v_l[il]),
+            (kv_head)*ggml_element_size(kv.v_l[il]));
     cb(v_cache_view, "v_cache_view", il);
 
     // important: storing RoPE-ed version of K in the KV cache!
@@ -4000,11 +3942,11 @@ static struct ggml_tensor * llm_build_kqv(
     cb(q, "q", il);
 
     struct ggml_tensor * k =
-        ggml_view_3d(ctx, kv.k,
+        ggml_view_3d(ctx, kv.k_l[il],
                 n_embd_head, n_kv, n_head_kv,
-                ggml_element_size(kv.k)*n_embd_gqa,
-                ggml_element_size(kv.k)*n_embd_head,
-                ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
+                ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
+                ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
+                0);
     cb(k, "k", il);
 
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@@ -4035,11 +3977,11 @@ static struct ggml_tensor * llm_build_kqv(
 
     // split cached v into n_head heads
     struct ggml_tensor * v =
-        ggml_view_3d(ctx, kv.v,
+        ggml_view_3d(ctx, kv.v_l[il],
                 n_kv, n_embd_head, n_head_kv,
-                ggml_element_size(kv.v)*n_ctx,
-                ggml_element_size(kv.v)*n_ctx*n_embd_head,
-                ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
+                ggml_element_size(kv.v_l[il])*n_ctx,
+                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
+                0);
     cb(v, "v", il);
 
     struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
@@ -4631,6 +4573,7 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
         cb(inpL, "imp_embd", -1);
 
+        // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
@@ -4638,6 +4581,7 @@ struct llm_build_context {
         struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
         cb(KQ_scale, "KQ_scale", -1);
 
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
 
@@ -5237,15 +5181,15 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
         cb(KQ_scale, "KQ_scale", -1);
 
-        // KQ_mask (mask for 1 head, it wil be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
@@ -5351,8 +5295,8 @@ struct llm_build_context {
 enum llm_offload_func_e {
     OFFLOAD_FUNC_NOP,
     OFFLOAD_FUNC,
-    OFFLOAD_FUNC_KQ,
-    OFFLOAD_FUNC_V,
+    OFFLOAD_FUNC_FRC, // force offload
+    OFFLOAD_FUNC_KQV,
     OFFLOAD_FUNC_NR,
     OFFLOAD_FUNC_EMB,
     OFFLOAD_FUNC_OUT,
@@ -5438,11 +5382,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
   //{ "inp_embd",                   OFFLOAD_FUNC_NR  }, // TODO: missing K-quants get_rows kernel
     { "pos_embd",                   OFFLOAD_FUNC_NR  },
 
-    { "inp_pos",                    OFFLOAD_FUNC_KQ  }, // this is often used for KQ ops (e.g. rope)
-    { "KQ_scale",                   OFFLOAD_FUNC_KQ  },
-    { "KQ_mask",                    OFFLOAD_FUNC_KQ  },
-    { "K_shift",                    OFFLOAD_FUNC_KQ  },
-    { "K_shifted",                  OFFLOAD_FUNC_KQ  },
+    { "inp_pos",                    OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
+    { "KQ_scale",                   OFFLOAD_FUNC_FRC },
+    { "KQ_mask",                    OFFLOAD_FUNC_FRC },
+    { "K_shift",                    OFFLOAD_FUNC_FRC },
+
+    { "K_shifted",                  OFFLOAD_FUNC     },
 
     { "inp_norm",                   OFFLOAD_FUNC_NR  },
     { "inp_norm_w",                 OFFLOAD_FUNC_NR  },
@@ -5455,38 +5400,38 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "attn_norm",                  OFFLOAD_FUNC     },
     { "attn_norm_2",                OFFLOAD_FUNC     },
 
-    { "wqkv",                       OFFLOAD_FUNC_KQ  },
-    { "bqkv",                       OFFLOAD_FUNC_KQ  },
-    { "wqkv_clamped",               OFFLOAD_FUNC_KQ  },
-
-    { "tmpk",                       OFFLOAD_FUNC_KQ  },
-    { "tmpq",                       OFFLOAD_FUNC_KQ  },
-    { "tmpv",                       OFFLOAD_FUNC_V   },
-    { "Kcur",                       OFFLOAD_FUNC_KQ  },
-    { "Qcur",                       OFFLOAD_FUNC_KQ  },
-    { "Vcur",                       OFFLOAD_FUNC_V   },
-
-    { "krot",                       OFFLOAD_FUNC_KQ  },
-    { "qrot",                       OFFLOAD_FUNC_KQ  },
-    { "kpass",                      OFFLOAD_FUNC_KQ  },
-    { "qpass",                      OFFLOAD_FUNC_KQ  },
-    { "krotated",                   OFFLOAD_FUNC_KQ  },
-    { "qrotated",                   OFFLOAD_FUNC_KQ  },
-
-    { "q",                          OFFLOAD_FUNC_KQ  },
-    { "k",                          OFFLOAD_FUNC_KQ  },
-    { "kq",                         OFFLOAD_FUNC_KQ  },
-    { "kq_scaled",                  OFFLOAD_FUNC_KQ  },
-    { "kq_scaled_alibi",            OFFLOAD_FUNC_KQ  },
-    { "kq_masked",                  OFFLOAD_FUNC_KQ  },
-    { "kq_soft_max",                OFFLOAD_FUNC_V   },
-    { "kq_soft_max_ext",            OFFLOAD_FUNC_V   },
-    { "v",                          OFFLOAD_FUNC_V   },
-    { "kqv",                        OFFLOAD_FUNC_V   },
-    { "kqv_merged",                 OFFLOAD_FUNC_V   },
-    { "kqv_merged_cont",            OFFLOAD_FUNC_V   },
-    { "kqv_wo",                     OFFLOAD_FUNC_V   },
-    { "kqv_out",                    OFFLOAD_FUNC_V   },
+    { "wqkv",                       OFFLOAD_FUNC_KQV },
+    { "bqkv",                       OFFLOAD_FUNC_KQV },
+    { "wqkv_clamped",               OFFLOAD_FUNC_KQV },
+
+    { "tmpk",                       OFFLOAD_FUNC_KQV },
+    { "tmpq",                       OFFLOAD_FUNC_KQV },
+    { "tmpv",                       OFFLOAD_FUNC_KQV },
+    { "Kcur",                       OFFLOAD_FUNC_KQV },
+    { "Qcur",                       OFFLOAD_FUNC_KQV },
+    { "Vcur",                       OFFLOAD_FUNC_KQV },
+
+    { "krot",                       OFFLOAD_FUNC_KQV },
+    { "qrot",                       OFFLOAD_FUNC_KQV },
+    { "kpass",                      OFFLOAD_FUNC_KQV },
+    { "qpass",                      OFFLOAD_FUNC_KQV },
+    { "krotated",                   OFFLOAD_FUNC_KQV },
+    { "qrotated",                   OFFLOAD_FUNC_KQV },
+
+    { "q",                          OFFLOAD_FUNC_KQV },
+    { "k",                          OFFLOAD_FUNC_KQV },
+    { "kq",                         OFFLOAD_FUNC_KQV },
+    { "kq_scaled",                  OFFLOAD_FUNC_KQV },
+    { "kq_scaled_alibi",            OFFLOAD_FUNC_KQV },
+    { "kq_masked",                  OFFLOAD_FUNC_KQV },
+    { "kq_soft_max",                OFFLOAD_FUNC_KQV },
+    { "kq_soft_max_ext",            OFFLOAD_FUNC_KQV },
+    { "v",                          OFFLOAD_FUNC_KQV },
+    { "kqv",                        OFFLOAD_FUNC_KQV },
+    { "kqv_merged",                 OFFLOAD_FUNC_KQV },
+    { "kqv_merged_cont",            OFFLOAD_FUNC_KQV },
+    { "kqv_wo",                     OFFLOAD_FUNC_KQV },
+    { "kqv_out",                    OFFLOAD_FUNC_KQV },
 
     { "ffn_inp",                    OFFLOAD_FUNC     },
     { "ffn_norm",                   OFFLOAD_FUNC     },
@@ -5678,15 +5623,15 @@ static struct ggml_cgraph * llama_build_graph(
             { OFFLOAD_FUNC_NOP, "CPU" },
             { OFFLOAD_FUNC_OUT, "CPU" },
 #ifdef GGML_USE_CUBLAS
-            { OFFLOAD_FUNC,     "GPU (CUDA)" },
-            { OFFLOAD_FUNC_KQ,  "GPU (CUDA) KQ" },
-            { OFFLOAD_FUNC_V,   "GPU (CUDA) V" },
-            { OFFLOAD_FUNC_NR,  "GPU (CUDA) NR" },
+            { OFFLOAD_FUNC,     "GPU (CUDA)"     },
+            { OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" },
+            { OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" },
+            { OFFLOAD_FUNC_NR,  "GPU (CUDA) NR"  },
             { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
 #else
             { OFFLOAD_FUNC,     "CPU" },
-            { OFFLOAD_FUNC_KQ,  "CPU" },
-            { OFFLOAD_FUNC_V,   "CPU" },
+            { OFFLOAD_FUNC_FRC, "CPU" },
+            { OFFLOAD_FUNC_KQV, "CPU" },
             { OFFLOAD_FUNC_NR,  "CPU" },
             { OFFLOAD_FUNC_EMB, "CPU" },
 #endif // GGML_USE_CUBLAS
@@ -5719,18 +5664,23 @@ static struct ggml_cgraph * llama_build_graph(
                     }
                 }
                 break;
-            case OFFLOAD_FUNC_NR:
-                if (n_gpu_layers <= n_layer + 0) {
+            case OFFLOAD_FUNC_FRC:
+                if (!lctx.cparams.offload_kqv) {
                     func_e = OFFLOAD_FUNC_NOP;
-                }
-                break;
-            case OFFLOAD_FUNC_V:
-                if (n_gpu_layers <= n_layer + 1) {
+                } break;
+            case OFFLOAD_FUNC_KQV:
+                if (!lctx.cparams.offload_kqv) {
                     func_e = OFFLOAD_FUNC_NOP;
+                } else {
+                    if (n_gpu_layers < n_layer) {
+                        if (il < i_gpu_start) {
+                            func_e = OFFLOAD_FUNC_NOP;
+                        }
+                    }
                 }
                 break;
-            case OFFLOAD_FUNC_KQ:
-                if (n_gpu_layers <= n_layer + 2) {
+            case OFFLOAD_FUNC_NR:
+                if (n_gpu_layers <= n_layer + 0) {
                     func_e = OFFLOAD_FUNC_NOP;
                 }
                 break;
@@ -5755,8 +5705,8 @@ static struct ggml_cgraph * llama_build_graph(
             case OFFLOAD_FUNC_NOP:
             case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
             case OFFLOAD_FUNC:
-            case OFFLOAD_FUNC_KQ:
-            case OFFLOAD_FUNC_V:
+            case OFFLOAD_FUNC_KQV:
+            case OFFLOAD_FUNC_FRC:
             case OFFLOAD_FUNC_NR:
             case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break;
             default: GGML_ASSERT(false);
@@ -5942,6 +5892,7 @@ static int llama_decode_internal(
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+    //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
@@ -5992,7 +5943,7 @@ static int llama_decode_internal(
         n_threads = std::min(4, n_threads);
     }
 
-    const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
+    const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
     if (ggml_cpu_has_cublas() && fully_offloaded) {
         n_threads = 1;
     }
@@ -8821,10 +8772,12 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast              =*/ 32.0f,
         /*.yarn_beta_slow              =*/ 1.0f,
         /*.yarn_orig_ctx               =*/ 0,
+        /*.type_k                      =*/ GGML_TYPE_F16,
+        /*.type_v                      =*/ GGML_TYPE_F16,
         /*.mul_mat_q                   =*/ true,
-        /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
+        /*.offload_kqv                 =*/ true,
     };
 
     return result;
@@ -8941,6 +8894,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.mul_mat_q        = params.mul_mat_q;
+    cparams.offload_kqv      = params.offload_kqv;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -8974,19 +8928,36 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+    const ggml_type type_k = params.type_k;
+    const ggml_type type_v = params.type_v;
+
+    GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
 
     // reserve memory for context buffers
     if (!hparams.vocab_only) {
-        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
+        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
         {
-            const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
-            LLAMA_LOG_INFO("%s: kv self size  = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+
+            for (auto & k : ctx->kv_self.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+
+            for (auto & v : ctx->kv_self.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
         // resized during inference
@@ -9057,8 +9028,12 @@ struct llama_context * llama_new_context_with_model(
             }
 
             size_t kv_vram_size = 0;
-            add_tensor(ctx->kv_self.k, kv_vram_size);
-            add_tensor(ctx->kv_self.v, kv_vram_size);
+            for (auto & k : ctx->kv_self.k_l) {
+                add_tensor(k, kv_vram_size);
+            }
+            for (auto & v : ctx->kv_self.v_l) {
+                add_tensor(v, kv_vram_size);
+            }
 
             size_t ctx_vram_size = alloc_size + kv_vram_size;
             size_t total_vram_size = model_vram_size + ctx_vram_size;
@@ -9528,37 +9503,45 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         data_ctx->write(&kv_used,     sizeof(kv_used));
 
         if (kv_buf_size) {
-            const size_t elt_size = ggml_element_size(kv_self.k);
+            const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
 
-            ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+            ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
             ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
 
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
-            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
-            kout3d->data = kout3d_data.data();
+            std::vector<std::vector<uint8_t>> kout2d_data(n_layer);
+            std::vector<std::vector<uint8_t>> vout2d_data(n_layer);
+
+            for (int il = 0; il < (int) n_layer; ++il) {
+                ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
+                kout2d_data[il].resize(ggml_nbytes(kout2d));
+                kout2d->data = kout2d_data[il].data();
+
+                ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                vout2d_data[il].resize(ggml_nbytes(vout2d));
+                vout2d->data = vout2d_data[il].data();
 
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
-            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
-            vout3d->data = vout3d_data.data();
+                ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
+                        n_embd, kv_head,
+                        elt_size*n_embd, 0);
 
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_head, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+                ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
+                        kv_head, n_embd,
+                        elt_size*n_ctx, 0);
 
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_head, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d));
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d));
+            }
 
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
 
-            // our data is now in the kout3d_data and vout3d_data buffers
+            // our data is now in the kout2d_data and vout2d_data buffers
             // write them to file
-            data_ctx->write(kout3d_data.data(), kout3d_data.size());
-            data_ctx->write(vout3d_data.data(), vout3d_data.size());
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size());
+                data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size());
+            }
         }
 
         for (uint32_t i = 0; i < kv_size; ++i) {
@@ -9658,29 +9641,32 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
 
-            const size_t elt_size = ggml_element_size(kv_self.k);
+            const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
 
-            ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+            ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
             ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
 
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
-            kin3d->data = (void *) inp;
-            inp += ggml_nbytes(kin3d);
+            for (int il = 0; il < n_layer; ++il) {
+                ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
+                kin2d->data = (void *) inp;
+                inp += ggml_nbytes(kin2d);
 
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
-            vin3d->data = (void *) inp;
-            inp += ggml_nbytes(vin3d);
+                ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                vin2d->data = (void *) inp;
+                inp += ggml_nbytes(vin2d);
 
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_head, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+                ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
+                    n_embd, kv_head,
+                    elt_size*n_embd, 0);
 
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_head, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+                ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
+                    kv_head, n_embd,
+                    elt_size*n_ctx, 0);
+
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d));
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d));
+            }
 
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
             ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
diff --git a/llama.h b/llama.h
index 517245a3543004969306984ebcd8f147d90373c8..b1f5fca624d69def6989d0f0c1af7538448563ed 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -42,7 +42,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 2
+#define LLAMA_SESSION_VERSION 3
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -211,11 +211,14 @@ extern "C" {
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
 
+        enum ggml_type type_k; // data type for K cache
+        enum ggml_type type_v; // data type for V cache
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
-        bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
-        bool logits_all; // the llama_eval() call computes all logits, not just the last one
-        bool embedding;  // embedding mode only
+        bool mul_mat_q;   // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
+        bool logits_all;  // the llama_eval() call computes all logits, not just the last one
+        bool embedding;   // embedding mode only
+        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
     };
 
     // model quantization parameters