]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
dolly-v2 : minor formatting
authorGeorgi Gerganov <redacted>
Mon, 8 May 2023 15:03:47 +0000 (18:03 +0300)
committerGeorgi Gerganov <redacted>
Mon, 8 May 2023 15:03:47 +0000 (18:03 +0300)
examples/dolly-v2/main.cpp

index 5825e838907e7cc1dd9ff34c882475d3348dce2b..4eaa2fee8ace7e9a24a24e38b200a3eea524afe4 100644 (file)
 // default hparams (Dolly-V2 3B)
 struct dollyv2_hparams {
     int32_t n_vocab = 50254; // tokenizer.vocab_size
-    int32_t n_ctx   = 2048; // model.config.max_position_embeddings
-    int32_t n_embd  = 2560; // model.config.hidden_size
-    int32_t n_head  = 32; // model.config.num_attention_heads
-    int32_t n_layer = 32; // model.config.num_hidden_layers
-    int32_t n_rot   = 20; // rotary_pct[25%] * (n_embd / n_head)
+    int32_t n_ctx   = 2048;  // model.config.max_position_embeddings
+    int32_t n_embd  = 2560;  // model.config.hidden_size
+    int32_t n_head  = 32;    // model.config.num_attention_heads
+    int32_t n_layer = 32;    // model.config.num_hidden_layers
+    int32_t n_rot   = 20;    // rotary_pct[25%] * (n_embd / n_head)
     int32_t ftype   = GGML_FTYPE_MOSTLY_F16;
 };
 
 const std::string INSTRUCTION_KEY = "### Instruction:";
-const std::string RESPONSE_KEY = "### Response:";
-const std::string END_KEY = "### End";
-const std::string INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request.";
+const std::string RESPONSE_KEY    = "### Response:";
+const std::string END_KEY         = "### End";
+const std::string INTRO_BLURB     = "Below is an instruction that describes a task. Write a response that appropriately completes the request.";
 
 // dollyv2 prompt format
-std::string promptForGenerationFormat(const std::string& instruction) {
+std::string prompt_for_generation(const std::string& instruction) {
     return INTRO_BLURB + "\n\n" + INSTRUCTION_KEY + "\n" + instruction + "\n\n" + RESPONSE_KEY + "\n";
 }
 
@@ -672,7 +672,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    std::string prompt = promptForGenerationFormat(params.prompt);
+    const std::string prompt = prompt_for_generation(params.prompt);
 
     int64_t t_load_us = 0;
 
@@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
     size_t mem_per_token = 0;
     dollyv2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
-    int32_t end_token = vocab.token_to_id["### End"];
+    const int32_t end_token = vocab.token_to_id["### End"];
 
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
@@ -775,7 +775,6 @@ int main(int argc, char ** argv) {
         if (embd.back() == 0 || (end_token > 0 && embd.back() == end_token)) {
             break;
         }
-
     }
 
     // report timing