]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sync : llama.cpp
authorGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 19:23:33 +0000 (21:23 +0200)
committerGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 19:23:33 +0000 (21:23 +0200)
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

index 7af38718c4130d1c329a1f48e5af4b3fc68655da..d28382f7d47b7ab9bfb7f44b37531890fad3feb0 100644 (file)
@@ -1393,6 +1393,9 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 struct llama_layer {
@@ -6254,6 +6257,7 @@ static int llama_decode_internal(
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
@@ -7898,39 +7902,59 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
+void llama_sample_apply_guidance(
+          struct llama_context * ctx,
+                         float * logits,
+                         float * logits_guidance,
+                         float   scale) {
+    GGML_ASSERT(ctx);
+
+    const auto t_start_sample_us = ggml_time_us();
+    const auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    llama_log_softmax(logits, n_vocab);
+    llama_log_softmax(logits_guidance, n_vocab);
+
+    for (int i = 0; i < n_vocab; ++i) {
+              auto & l = logits[i];
+        const auto & g = logits_guidance[i];
+
+        l = scale * (l - g) + g;
+    }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
+
 void llama_sample_classifier_free_guidance(
           struct llama_context * ctx,
         llama_token_data_array * candidates,
           struct llama_context * guidance_ctx,
                          float   scale) {
-    int64_t t_start_sample_us = ggml_time_us();
-
     GGML_ASSERT(ctx);
+    int64_t t_start_sample_us;
 
-    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+    t_start_sample_us = ggml_time_us();
+    const size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    GGML_ASSERT(n_vocab == (int)candidates->size);
+    GGML_ASSERT(n_vocab == candidates->size);
     GGML_ASSERT(!candidates->sorted);
 
-    std::vector<float> logits_base;
-    logits_base.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        logits_base.push_back(candidates->data[i].logit);
+    std::vector<float> logits_base(n_vocab);
+    for (size_t i = 0; i < n_vocab; ++i) {
+        logits_base[i] = candidates->data[i].logit;
     }
-    llama_log_softmax(logits_base.data(), candidates->size);
 
-    float* logits_guidance = llama_get_logits(guidance_ctx);
-    llama_log_softmax(logits_guidance, n_vocab);
+    float * logits_guidance = llama_get_logits(guidance_ctx);
 
-    for (int i = 0; i < n_vocab; ++i) {
-        float logit_guidance = logits_guidance[i];
-        float logit_base = logits_base[i];
-        candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
-    }
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_sample_apply_guidance(ctx, logits_base.data(), logits_guidance, scale);
+    t_start_sample_us = ggml_time_us();
 
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    for (size_t i = 0; i < n_vocab; ++i) {
+        candidates->data[i].logit = logits_base[i];
     }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
@@ -8354,6 +8378,8 @@ struct quantize_state_internal {
     int n_k_quantized     = 0;
     int n_fallback        = 0;
 
+    bool has_imatrix      = false;
+
     quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
         : model(model)
         , params(params)
@@ -8455,7 +8481,12 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
         }
         else if (name == "token_embd.weight") new_type = GGML_TYPE_Q2_K;
     } else if (name.find("attn_v.weight") != std::string::npos) {
-        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
+            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
+            new_type = GGML_TYPE_Q4_K;
+        }
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
             new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
         }
@@ -8526,6 +8557,13 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
             new_type = GGML_TYPE_Q5_K;
         }
+        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
+                && qs.has_imatrix && i_layer < n_layer/8) {
+            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
+            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
+            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
+            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
+        }
         ++qs.i_feed_forward_w2;
     } else if (name.find("attn_output.weight") != std::string::npos) {
         if (arch != LLM_ARCH_FALCON) {
@@ -8559,7 +8597,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
     //}
     bool convert_incompatible_tensor = false;
     if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
-        new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
+        new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K ||
+        new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
         if (nx % QK_K != 0) {
@@ -8571,6 +8610,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
     }
     if (convert_incompatible_tensor) {
         switch (new_type) {
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ2_XS:
             case GGML_TYPE_Q2_K: new_type = GGML_TYPE_Q4_0; break;
             case GGML_TYPE_Q3_K: new_type = GGML_TYPE_Q4_1; break;
             case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
@@ -8646,6 +8687,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
         if (imatrix_data) {
             LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
+            qs.has_imatrix = true;
         }
     }
 
@@ -8705,8 +8747,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
-    std::set<ggml_type> used_iq2;
-
     for (int i = 0; i < ml.n_tensors; ++i) {
         struct ggml_tensor * tensor = ml.get_tensor_meta(i);
 
@@ -8759,11 +8799,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else {
             const size_t nelements = ggml_nelements(tensor);
 
-            if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS) && used_iq2.find(new_type) == used_iq2.end()) {
-                ggml_init_iq2_quantization(new_type);
-                used_iq2.insert(new_type);
-            }
-
             const float * imatrix = nullptr;
             if (imatrix_data) {
                 auto it = imatrix_data->find(tensor->name);
@@ -8889,10 +8924,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     fout.close();
 
-    for (auto type : used_iq2) {
-        ggml_deinit_iq2_quantization(type);
-    }
-
     gguf_free(ctx_out);
 
     LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
@@ -9238,6 +9269,8 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast              =*/ 32.0f,
         /*.yarn_beta_slow              =*/ 1.0f,
         /*.yarn_orig_ctx               =*/ 0,
+        /*.cb_eval                     =*/ nullptr,
+        /*.cb_eval_user_data           =*/ nullptr,
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
         /*.mul_mat_q                   =*/ true,
@@ -9298,6 +9331,7 @@ void llama_backend_free(void) {
 #ifdef GGML_USE_MPI
     ggml_mpi_backend_free();
 #endif
+    ggml_quantize_free();
 }
 
 int64_t llama_time_us(void) {
@@ -9378,6 +9412,9 @@ struct llama_context * llama_new_context_with_model(
                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
                                                               hparams.n_ctx_train;
 
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
index 79c8335b66bdfa2d77d91e03c3ae2f19d59632ad..e268d7a1d0cc9bf9dd8290fdbb6da2f2959589d5 100644 (file)
@@ -2,6 +2,7 @@
 #define LLAMA_H
 
 #include "ggml.h"
+#include "ggml-backend.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@@ -231,6 +232,9 @@ extern "C" {
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
 
+        ggml_backend_sched_eval_callback cb_eval;
+        void * cb_eval_user_data;
+
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache
 
@@ -714,14 +718,21 @@ extern "C" {
                            float   penalty_present);
 
     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
-    /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_classifier_free_guidance(
+    /// @param logits Logits extracted from the original generation context.
+    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    LLAMA_API void llama_sample_apply_guidance(
+              struct llama_context * ctx,
+                             float * logits,
+                             float * logits_guidance,
+                             float   scale);
+
+    LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
               struct llama_context * ctx,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
-                             float   scale);
+                             float   scale),
+              "use llama_sample_apply_guidance() instead");
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(