// default hparams (Dolly-V2 3B)
struct dollyv2_hparams {
int32_t n_vocab = 50254; // tokenizer.vocab_size
- int32_t n_ctx = 2048; // model.config.max_position_embeddings
- int32_t n_embd = 2560; // model.config.hidden_size
- int32_t n_head = 32; // model.config.num_attention_heads
- int32_t n_layer = 32; // model.config.num_hidden_layers
- int32_t n_rot = 20; // rotary_pct[25%] * (n_embd / n_head)
+ int32_t n_ctx = 2048; // model.config.max_position_embeddings
+ int32_t n_embd = 2560; // model.config.hidden_size
+ int32_t n_head = 32; // model.config.num_attention_heads
+ int32_t n_layer = 32; // model.config.num_hidden_layers
+ int32_t n_rot = 20; // rotary_pct[25%] * (n_embd / n_head)
int32_t ftype = GGML_FTYPE_MOSTLY_F16;
};
const std::string INSTRUCTION_KEY = "### Instruction:";
-const std::string RESPONSE_KEY = "### Response:";
-const std::string END_KEY = "### End";
-const std::string INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request.";
+const std::string RESPONSE_KEY = "### Response:";
+const std::string END_KEY = "### End";
+const std::string INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request.";
// dollyv2 prompt format
-std::string promptForGenerationFormat(const std::string& instruction) {
+std::string prompt_for_generation(const std::string& instruction) {
return INTRO_BLURB + "\n\n" + INSTRUCTION_KEY + "\n" + instruction + "\n\n" + RESPONSE_KEY + "\n";
}
}
}
- std::string prompt = promptForGenerationFormat(params.prompt);
+ const std::string prompt = prompt_for_generation(params.prompt);
int64_t t_load_us = 0;
size_t mem_per_token = 0;
dollyv2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
- int32_t end_token = vocab.token_to_id["### End"];
+ const int32_t end_token = vocab.token_to_id["### End"];
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.back() == 0 || (end_token > 0 && embd.back() == end_token)) {
break;
}
-
}
// report timing