git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
author    Georgi Gerganov <redacted>
          Sun, 28 Jan 2024 17:44:10 +0000 (19:44 +0200)
committer Georgi Gerganov <redacted>
          Sun, 28 Jan 2024 17:44:10 +0000 (19:44 +0200)
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index b03b67e169955cc684148155c566d6db3f069ba2..f7d054c577aea5ecdab886a1baee2320b50a3337 100644
--- a/examples/talk-llama/llama.cpp
+++ b/examples/talk-llama/llama.cpp
 #  include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
 #  include "ggml-opencl.h"
+#elif defined(GGML_USE_VULKAN)
+#  include "ggml-vulkan.h"
+#elif defined(GGML_USE_SYCL)
+#  include "ggml-sycl.h"
 #endif
 
 #ifdef GGML_USE_METAL
@@ -52,6 +56,7 @@
 #include <algorithm>
 #include <array>
 #include <cassert>
+#include <cfloat>
 #include <cinttypes>
 #include <climits>
 #include <cmath>
@@ -196,6 +201,7 @@ enum llm_arch {
     LLM_ARCH_PHI2,
     LLM_ARCH_PLAMO,
     LLM_ARCH_CODESHELL,
+    LLM_ARCH_ORION,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -217,6 +223,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI2,            "phi2"      },
     { LLM_ARCH_PLAMO,           "plamo"     },
     { LLM_ARCH_CODESHELL,       "codeshell" },
+    { LLM_ARCH_ORION,           "orion"     },
 };
 
 enum llm_kv {
@@ -641,6 +648,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ORION,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 
     {
         LLM_ARCH_UNKNOWN,
@@ -1256,8 +1282,14 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
     if (host_buffer) {
         buft = ggml_backend_cuda_host_buffer_type();
     }
+#elif defined(GGML_USE_SYCL)
+    buft = ggml_backend_sycl_host_buffer_type();
 #elif defined(GGML_USE_CPU_HBM)
     buft = ggml_backend_cpu_hbm_buffer_type();
+#elif defined(GGML_USE_VULKAN)
+    if (host_buffer) {
+        buft = ggml_backend_vk_host_buffer_type();
+    }
 #endif
 
     if (buft == nullptr) {
@@ -1275,6 +1307,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
     buft = ggml_backend_metal_buffer_type();
 #elif defined(GGML_USE_CUBLAS)
     buft = ggml_backend_cuda_buffer_type(gpu);
+#elif defined(GGML_USE_VULKAN)
+    buft = ggml_backend_vk_buffer_type();
+#elif defined(GGML_USE_SYCL)
+    buft = ggml_backend_sycl_buffer_type(gpu);
 #elif defined(GGML_USE_CLBLAST)
     buft = ggml_backend_opencl_buffer_type();
 #endif
@@ -1332,6 +1368,7 @@ enum e_model {
     MODEL_7B,
     MODEL_8B,
     MODEL_13B,
+    MODEL_14B,
     MODEL_15B,
     MODEL_30B,
     MODEL_34B,
@@ -2683,6 +2720,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_7B:     return "7B";
         case MODEL_8B:     return "8B";
         case MODEL_13B:    return "13B";
+        case MODEL_14B:    return "14B";
         case MODEL_15B:    return "15B";
         case MODEL_30B:    return "30B";
         case MODEL_34B:    return "34B";
@@ -2950,7 +2988,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_ORION:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_14B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -3933,6 +3979,38 @@ static bool llm_load_tensors(
                         layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
                     }
                 } break;
+            case LLM_ARCH_ORION:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
+
+
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -4563,6 +4641,126 @@ struct llm_build_context {
             ctx0 = nullptr;
         }
     }
+    struct ggml_cgraph * build_orion() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                // if (model.layers[il].bq) {
+                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                //     cb(Qcur, "Qcur", il);
+                // }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                // if (model.layers[il].bk) {
+                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                //     cb(Kcur, "Kcur", il);
+                // }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                // if (model.layers[il].bv) {
+                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                //     cb(Vcur, "Vcur", il);
+                // }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up,   NULL,
+                    model.layers[il].ffn_gate, NULL,
+                    model.layers[il].ffn_down, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+
 
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
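Note: each Orion layer built by build_orion() above is a standard pre-norm decoder block, except that it uses full LayerNorm (weight and bias, LLM_NORM) rather than RMSNorm, and a SiLU-gated parallel FFN (LLM_FFN_SILU, LLM_FFN_PAR). As a sketch, using the callback names from the graph:

    ffn_inp = inpSA   + Attn(LayerNorm(inpL))                                                        // "kqv_out" residual
    l_out   = ffn_inp + Wdown * ( SiLU(Wgate * LayerNorm(ffn_inp)) ⊙ (Wup * LayerNorm(ffn_inp)) )    // "ffn_out" residual
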
@@ -6520,6 +6718,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_codeshell();
             } break;
+        case LLM_ARCH_ORION:
+            {
+                result = llm.build_orion();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -6652,7 +6854,7 @@ static int llama_decode_internal(
     }
 
     const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
-    if (ggml_cpu_has_cublas() && fully_offloaded) {
+    if ((ggml_cpu_has_cublas() || ggml_cpu_has_vulkan()) && fully_offloaded) {
         n_threads = 1;
     }
 
@@ -7946,6 +8148,11 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
 }
 
 void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+    // if (k >= (int32_t)candidates->size) {
+    //     return;
+    // }
+
     const int64_t t_start_sample_us = ggml_time_us();
 
     k = std::max(k, (int) min_keep);
@@ -8054,21 +8261,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
         return;
     }
 
-    llama_sample_softmax(ctx, candidates);
-
     const int64_t t_start_sample_us = ggml_time_us();
 
-    float scale = candidates->data[0].p; // scale by max prob
-    size_t i = 1; // first token always matches
+    bool min_p_applied = false;
+
+    // if the candidates aren't sorted, try the unsorted implementation first
+    if (!candidates->sorted) {
+        std::vector<llama_token_data> filtered_tokens;
 
-    for (; i < candidates->size; ++i) {
-        if (candidates->data[i].p < p * scale && i >= min_keep) {
-            break; // prob too small
+        float max_logit = -FLT_MAX;
+        for (size_t i = 0; i < candidates->size; ++i) {
+            max_logit = std::max(max_logit, candidates->data[i].logit);
+        }
+        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+
+        for (size_t i = 0; i < candidates->size; ++i) {
+            if (candidates->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(candidates->data[i]);
+            }
+        }
+
+        // if we have enough values the operation was a success
+        if (filtered_tokens.size() >= min_keep) {
+            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            candidates->size = filtered_tokens.size();
+            min_p_applied = true;
         }
     }
 
-    // Resize the output vector to keep only the matching tokens
-    candidates->size = i;
+    // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    if (!min_p_applied) {
+        // Sort the logits in descending order
+        if (!candidates->sorted) {
+            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+                return a.logit > b.logit;
+            });
+            candidates->sorted = true;
+        }
+
+        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        size_t i = 1; // first token always matches
+
+        for (; i < candidates->size; ++i) {
+            if (candidates->data[i].logit < min_logit && i >= min_keep) {
+                break; // prob too small
+            }
+        }
+
+        // Resize the output vector to keep only the matching tokens
+        candidates->size = i;
+    }
 
     if (ctx) {
         ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
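
The rewritten llama_sample_min_p above no longer calls llama_sample_softmax first: since p_i is proportional to exp(logit_i), the condition p_i >= p * p_max is equivalent to logit_i >= max_logit + log(p), so the filter can run directly on the raw logits without sorting or normalizing. A minimal standalone sketch of that unsorted path (Candidate and min_p_filter are illustrative names, not the library API):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative stand-in for llama_token_data.
struct Candidate { int id; float logit; };

// Keep only candidates with p_i >= min_p * p_max. Because p_i ∝ exp(logit_i),
// this is the same as logit_i >= max_logit + log(min_p); no softmax needed.
static void min_p_filter(std::vector<Candidate> & cands, float min_p, size_t min_keep) {
    if (min_p <= 0.0f || cands.empty()) {
        return;
    }

    float max_logit = -FLT_MAX;
    for (const auto & c : cands) {
        max_logit = std::max(max_logit, c.logit);
    }
    const float min_logit = max_logit + logf(min_p);

    std::vector<Candidate> kept;
    for (const auto & c : cands) {
        if (c.logit >= min_logit) {
            kept.push_back(c);
        }
    }

    // only apply the filter if enough tokens survive (mirrors min_keep above)
    if (kept.size() >= min_keep) {
        cands.swap(kept);
    }
}

int main() {
    std::vector<Candidate> cands = { {0, 2.0f}, {1, 1.0f}, {2, -3.0f} };
    min_p_filter(cands, 0.1f, 1); // log(0.1) ≈ -2.3, so only token 2 falls below the cutoff
    for (const auto & c : cands) {
        printf("token %d logit %.2f\n", c.id, c.logit);
    }
    return 0;
}
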
@@ -9997,6 +10239,26 @@ struct llama_context * llama_new_context_with_model(
                 }
             }
         }
+#elif defined(GGML_USE_VULKAN)
+        if (model->n_gpu_layers > 0) {
+            ggml_backend_t backend = ggml_backend_vk_init();
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+#elif defined(GGML_USE_SYCL)
+        if (model->n_gpu_layers > 0) {
+            ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
 #endif
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
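
With the hunks above, a build compiled with GGML_USE_VULKAN or GGML_USE_SYCL creates the corresponding GPU backend whenever n_gpu_layers > 0. A rough usage sketch against the llama.h API as of this sync (error handling trimmed):

#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        return 1;
    }

    llama_backend_init(false); // numa = false

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99; // offload everything the backend can take

    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // on a Vulkan/SYCL build, context creation goes through the new
    // ggml_backend_vk_init / ggml_backend_sycl_init branches added above

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
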
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index 7b3634aa68505cfeaf77790aef66af74ae67b394..3e33072c68c17edd2a61f45d47ffece144788a7b 100644
--- a/examples/talk-llama/llama.h
+++ b/examples/talk-llama/llama.h
@@ -6,6 +6,9 @@
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
+#elif defined(GGML_USE_SYCL)
+#include "ggml-sycl.h"
+#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
 #else
 #define LLAMA_MAX_DEVICES 1
 #endif // GGML_USE_CUBLAS
@@ -46,7 +49,7 @@
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 4
 
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif
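
The llama.h changes extend LLAMA_MAX_DEVICES (now taken from GGML_SYCL_MAX_DEVICES on SYCL builds) and LLAMA_SUPPORTS_GPU_OFFLOAD to the new backends, so applications can keep gating GPU options at compile time. A small sketch (set_gpu_layers is an illustrative helper, not part of the API):

#include "llama.h"

#include <cstdio>

static void set_gpu_layers(llama_model_params & mparams, int n_gpu_layers) {
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
    mparams.n_gpu_layers = n_gpu_layers;
#else
    (void) n_gpu_layers;
    fprintf(stderr, "warning: built without GPU offload support, ignoring n_gpu_layers\n");
#endif
}

int main() {
    llama_model_params mparams = llama_model_default_params();
    set_gpu_layers(mparams, 32);
    printf("max devices: %d\n", (int) LLAMA_MAX_DEVICES);
    return 0;
}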