talk-llama : sync llama.cpp
author     Georgi Gerganov <redacted>
           Tue, 19 Nov 2024 17:08:57 +0000 (19:08 +0200)
committer  Georgi Gerganov <redacted>
           Wed, 20 Nov 2024 19:00:08 +0000 (21:00 +0200)
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index 97eee26a577bccd357414c5c8ae6713f87c4c174..c51b36e66042e214553fb9d08e24516dc4c5e598 100644
@@ -179,6 +179,7 @@ enum llm_arch {
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
+    LLM_ARCH_OLMO_1124,
     LLM_ARCH_OLMOE,
     LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
@@ -232,6 +233,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COMMAND_R,       "command-r"    },
     { LLM_ARCH_DBRX,            "dbrx"         },
     { LLM_ARCH_OLMO,            "olmo"         },
+    { LLM_ARCH_OLMO_1124,       "olmo_1124"    },
     { LLM_ARCH_OLMOE,           "olmoe"        },
     { LLM_ARCH_OPENELM,         "openelm"      },
     { LLM_ARCH_ARCTIC,          "arctic"       },
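
Note: each architecture is registered twice, once in the llm_arch enum and once in LLM_ARCH_NAMES, and the loader resolves the GGUF "general.architecture" string through this map. A minimal sketch of that reverse lookup (llama.cpp has its own helper, llm_arch_from_string; this standalone version is illustrative):

    // string -> enum lookup over the arch/name table above
    static llm_arch llm_arch_from_string(const std::string & name) {
        for (const auto & kv : LLM_ARCH_NAMES) {
            if (kv.second == name) {
                return kv.first;
            }
        }
        return LLM_ARCH_UNKNOWN;
    }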
@@ -1207,6 +1209,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_OLMO_1124,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_OLMOE,
         {
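
Note: the "%d" in each mapped name is a per-layer placeholder that the loader expands with the layer index before appending a suffix such as "weight". A sketch of that expansion, assuming snprintf semantics like the tn() helper used in llm_load_tensors (GGML_MAX_NAME is ggml's name-buffer size):

    #include <cstdio>
    #include <string>

    // "blk.%d.attn_q" + "weight", il = 3  ->  "blk.3.attn_q.weight"
    static std::string tensor_name(const char * fmt, const char * suffix, int il) {
        char base[GGML_MAX_NAME];
        snprintf(base, sizeof(base), fmt, il);
        return std::string(base) + "." + suffix;
    }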
@@ -2907,9 +2928,15 @@ struct llama_model {
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
-    int64_t t_load_us = 0;
+    int64_t t_load_us  = 0;
     int64_t t_start_us = 0;
 
+    // total number of parameters in the model
+    uint64_t n_elements = 0;
+
+    // total size of all the tensors in the model in bytes
+    size_t  n_bytes     = 0;
+
     // keep track of loaded lora adapters
     std::set<struct llama_lora_adapter *> lora_adapters;
 
@@ -3454,21 +3481,13 @@ static bool llama_kv_cache_init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        const llama_model::buft_list_t * buft_list;
+        ggml_backend_buffer_type_t buft;
         if (offload) {
-            buft_list = model.dev_layer.at(i).buft_list;
+            auto * dev = model.dev_layer.at(i).dev;
+            buft = ggml_backend_dev_buffer_type(dev);
         } else {
-            buft_list = &model.cpu_buft_list;
+            buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_backend_buffer_type_t buft = select_buft(*buft_list,
-            [&](ggml_context * ctx) {
-                ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-                if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
-                    return k;
-                }
-                ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-                return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
-            });
         ggml_context * ctx = ctx_for_buft(buft);
 
         if (!ctx) {
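
Note: this hunk drops the select_buft() probe (which built a trial RoPE op to find a buffer type that could run the K-shift) in favor of a direct query: the KV cache for layer i uses the default buffer type of whichever device that layer lives on. The selection logic in isolation, as a sketch (the wrapper name is invented):

    static ggml_backend_buffer_type_t pick_kv_buft(const llama_model & model, int il, bool offload) {
        if (offload) {
            ggml_backend_dev_t dev = model.dev_layer.at(il).dev;
            return ggml_backend_dev_buffer_type(dev);  // the device's native buffer type
        }
        return ggml_backend_cpu_buffer_type();         // KV cache stays in host memory
    }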
@@ -4275,8 +4294,8 @@ struct llama_model_loader {
     int n_tensors = 0;
     int n_created = 0;
 
-    int64_t n_elements = 0;
-    size_t  n_bytes    = 0;
+    uint64_t n_elements = 0;
+    size_t  n_bytes     = 0;
 
     bool use_mmap = false;
     bool check_tensors;
@@ -5344,6 +5363,11 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
     }
 }
 
+static void llm_load_stats(llama_model_loader & ml, llama_model & model) {
+    model.n_elements = ml.n_elements;
+    model.n_bytes = ml.n_bytes;
+}
+
 static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
     model.arch = ml.get_arch();
     if (model.arch == LLM_ARCH_UNKNOWN) {
@@ -5874,6 +5898,17 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OLMO_1124:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 16: model.type = e_model::MODEL_1B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_OLMOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
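
Note: llm_load_hparams reads the architecture's metadata keys and then maps layer count to a size label, as the OLMO_1124 case above does (16/32/40 layers -> 1B/7B/13B). get_key() throws when a required key is missing; a third argument makes the read optional. A sketch of both call forms (the optional key here is just an example):

    // required: the load fails if the arch-prefixed layer_norm_rms_epsilon key is absent
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

    // optional: the existing default is kept when the key is absent
    ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, /*required =*/ false);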
@@ -7254,7 +7289,7 @@ static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
     auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_cpu_get_extra_bufts");
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
     if (ggml_backend_dev_get_extra_bufts_fn) {
         ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
         while (extra_bufts && *extra_bufts) {
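
Note: the string passed to ggml_backend_reg_get_proc_address is a runtime key rather than a linked symbol, so a stale key (here the pre-rename "ggml_backend_cpu_get_extra_bufts") returns NULL and the extra CPU buffer types are silently skipped; the NULL check after the lookup is what makes the rename above a behavior fix rather than a crash fix. The pattern in isolation:

    auto fn = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
    if (fn) {
        // only reached when the CPU backend actually exports this entry point
    }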
@@ -7521,7 +7556,7 @@ static bool llm_load_tensors(
 
             // avoid using a host buffer when using mmap
             auto * buft_dev = ggml_backend_buft_get_device(buft);
-            if (ml.use_mmap && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
                 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
                 buft = ggml_backend_dev_buffer_type(cpu_dev);
             }
@@ -8556,6 +8591,31 @@ static bool llm_load_tensors(
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_OLMO_1124:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_OLMOE:
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
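
Note: in the OLMO_1124 tensor shapes above, the Q projection keeps the full model width while K/V shrink under grouped-query attention, which is why wq is {n_embd, n_embd} but wk/wv are {n_embd, n_embd_gqa}. The relation, with illustrative numbers (when n_head_kv == n_head the two widths coincide):

    const int64_t n_embd_head_k = n_embd / n_head;            // e.g. 4096 / 32 = 128
    const int64_t n_embd_gqa    = n_embd_head_k * n_head_kv;  // 32 kv heads -> 4096; 8 -> 1024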
@@ -9128,6 +9188,10 @@ static bool llm_load_tensors(
 
         // check if it is possible to use buffer_from_host_ptr with this buffer type
         ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (!dev) {
+            // FIXME: workaround for CPU backend buft having a NULL device
+            dev = ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0);
+        }
         ggml_backend_dev_props props;
         ggml_backend_dev_get_props(dev, &props);
         bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
@@ -9252,6 +9316,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
+        llm_load_stats(ml, model);
         llm_load_print_meta(ml, model);
 
         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
@@ -14416,6 +14481,130 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_olmo_1124() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = inpL;
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_ffn(ctx0, lctx, ffn_inp,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // based on the build_qwen2moe() function, changes:
     //   * removed shared experts
     //   * removed bias
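
Note: compared with build_olmo(), build_olmo_1124() above is a post-norm variant: neither sub-block normalizes its input; instead attn_post_norm and ffn_post_norm are applied to each sub-block's output before the residual add, and Q/K get their own RMSNorm ahead of RoPE. Schematically, per layer:

    // h      = x + RMSNorm(Attention(x))    -- attn_post_norm before the residual add
    // x_next = h + RMSNorm(SwiGLU_FFN(h))   -- ffn_post_norm before the residual add
    // plus RMSNorm on the Q and K projections before rotary embedding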
@@ -16608,6 +16797,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OLMO_1124:
+            {
+                result = llm.build_olmo_1124();
+            } break;
         case LLM_ARCH_OLMOE:
             {
                 result = llm.build_olmoe();
@@ -18020,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
-            GGML_ABORT("Deepseek2 does not support K-shift");
+        if (!llama_kv_cache_can_shift(&lctx)) {
+            GGML_ABORT("The current context does not support K-shift");
         }
 
@@ -18597,6 +18790,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     llama_model model;
     llm_load_arch(ml, model);
     llm_load_hparams(ml, model);
+    llm_load_stats(ml, model);
 
     struct quantize_state_internal qs(model, params);
 
@@ -19876,6 +20070,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMO_1124:
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
@@ -19949,19 +20144,11 @@ int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t bu
 }
 
 uint64_t llama_model_size(const struct llama_model * model) {
-    uint64_t size = 0;
-    for (const auto & it : model->tensors_by_name) {
-        size += ggml_nbytes(it.second);
-    }
-    return size;
+    return model->n_bytes;
 }
 
 uint64_t llama_model_n_params(const struct llama_model * model) {
-    uint64_t nparams = 0;
-    for (const auto & it : model->tensors_by_name) {
-        nparams += ggml_nelements(it.second);
-    }
-    return nparams;
+    return model->n_elements;
 }
 
 struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
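
Note: with n_elements and n_bytes captured once by llm_load_stats(), llama_model_size() and llama_model_n_params() become O(1) lookups instead of per-call walks over tensors_by_name. The public API is unchanged; a minimal usage sketch (the model path is illustrative):

    #include <cinttypes>
    #include <cstdio>
    #include "llama.h"

    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());
    if (model) {
        printf("params: %" PRIu64 "  size: %.2f GiB\n",
               llama_model_n_params(model),
               llama_model_size(model) / (1024.0 * 1024.0 * 1024.0));
        llama_free_model(model);
    }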
@@ -20275,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
     llama_kv_cache_update_internal(*ctx);
 }
 
+bool llama_kv_cache_can_shift(struct llama_context * ctx) {
+    return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+}
+
 // deprecated
 size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);
@@ -22021,7 +22212,6 @@ const char * llama_print_system_info(void) {
     s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";
     s += "RISCV_VECT = "  + std::to_string(ggml_cpu_has_riscv_v())     + " | ";
     s += "WASM_SIMD = "   + std::to_string(ggml_cpu_has_wasm_simd())   + " | ";
-    s += "BLAS = "        + std::to_string(ggml_cpu_has_blas())        + " | ";
     s += "SSE3 = "        + std::to_string(ggml_cpu_has_sse3())        + " | ";
     s += "SSSE3 = "       + std::to_string(ggml_cpu_has_ssse3())       + " | ";
     s += "VSX = "         + std::to_string(ggml_cpu_has_vsx())         + " | ";
@@ -22067,28 +22257,6 @@ void llama_perf_context_reset(struct llama_context * ctx) {
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
 }
 
-void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
-    fprintf(stream, "\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "# Timings #\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "\n");
-
-    fprintf(stream, "mst_eval: %.2f  # ms / token during generation\n",
-            1.0e-3 * ctx->t_eval_us / ctx->n_eval);
-    fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
-            1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
-    fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
-    fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
-    fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
-            1.0e6 * ctx->n_eval / ctx->t_eval_us);
-    fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
-            1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-}
-
 // For internal test use
 const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
     struct llama_context * ctx
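
Note: llama_perf_dump_yaml() is removed here and its declaration is dropped from llama.h below; the remaining llama_perf_context_print()/llama_perf_context_reset() pair covers the common case. A sketch of the replacement calls:

    llama_perf_context_print(ctx);  // logs load/prompt-eval/eval timings for this context
    llama_perf_context_reset(ctx);  // zero the counters, e.g. between benchmark runs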
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index 5e742642eec8de2ebdcbb3e756290829ff2623b5..90791d5f5ea126a53840030d9e95348739ba4360 100644
@@ -667,6 +667,9 @@ extern "C" {
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
 
+    // Check if the context supports KV cache shifting
+    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
+
     //
     // State / sessions
     //
@@ -1244,8 +1247,6 @@ extern "C" {
     LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
     LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
 
-    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
-
 #ifdef __cplusplus
 }
 #endif