#include "ggml.h"
+#include "ggml-alloc.h"
#include "common.h"
#include "llama.h"
#include <unordered_map>
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static const float rms_norm_eps = 1e-5f;
-
struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
return rnd->rd(rnd->gen);
}
-void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
-
- if (plan.work_size > 0) {
- buf.resize(plan.work_size);
- plan.work_data = buf.data();
- }
-
- ggml_graph_compute(graph, &plan);
-}
-
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
float scale = 1.0f; // xavier
switch (tensor->n_dims) {
return tensor;
}
-struct llama_vocab {
- using id = int32_t;
- using token = std::string;
- using ttype = llama_token_type;
-
- struct token_data {
- token text;
- float score;
- ttype type;
- };
-
- std::unordered_map<token, id> token_to_id;
- std::vector<token_data> id_to_token;
-};
-
struct my_llama_hparams {
uint32_t n_vocab = 32000;
- uint32_t n_ctx = 512; // this is provided as user input?
+ uint32_t n_ctx = 512;
uint32_t n_embd = 4096;
- uint32_t n_mult = 4;
uint32_t n_head = 32;
uint32_t n_layer = 32;
uint32_t n_rot = 64;
+ uint32_t n_ff = 11008;
+
+ // float f_norm_eps = 1e-5; // falcon
+ float f_norm_rms_eps = 1e-5; // llama
+
+ float rope_freq_base = 10000.0f;
+ float rope_freq_scale = 1.0f;
bool operator!=(const my_llama_hparams& other) const {
return memcmp(this, &other, sizeof(my_llama_hparams));
struct ggml_tensor * w3;
};
-struct my_llama_kv_cache {
- struct ggml_context * ctx = NULL;
-
- struct ggml_tensor * k;
- struct ggml_tensor * v;
-
- // llama_ctx_buffer buf;
-
- int n; // number of tokens currently in the cache
-};
-
struct my_llama_model {
struct ggml_context * ctx = NULL;
uint32_t train_tokens = 0;
};
-uint32_t get_n_ff(const struct my_llama_hparams* hparams) {
- const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
- return n_ff;
-}
+// gguf constants
+const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
+const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
+const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
+const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
+const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
+const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
+const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
+const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
+const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
+const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
+const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
+const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
+const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
+const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
+const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
+const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
+const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
+const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
+
+const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
+const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
+const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
+
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
+const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
+
+const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
+const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
+const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
+const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
+
+// gguf constants (sync with gguf.py)
+
+const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
+const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
+
+const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
+const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
+const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
+const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
+const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
+const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
+const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
+const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
+const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
+
+const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
+const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
+const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
+const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
+const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
+const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
+const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
+const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
+const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
+const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";
+
+const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
+const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
+const char * LLM_TENSOR_OUTPUT = "output";
+const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
+const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
+const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
+const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
+const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
+const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
+const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
+const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
+const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
void print_params(struct my_llama_hparams * params) {
printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
printf("%s: n_embd: %d\n", __func__, params->n_embd);
- printf("%s: n_mult: %d\n", __func__, params->n_mult);
printf("%s: n_head: %d\n", __func__, params->n_head);
- printf("%s: n_ff: %d\n", __func__, get_n_ff(params));
+ printf("%s: n_ff: %d\n", __func__, params->n_ff);
printf("%s: n_layer: %d\n", __func__, params->n_layer);
printf("%s: n_rot: %d\n", __func__, params->n_rot);
}
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
-
- const uint32_t n_ff = get_n_ff(&hparams);
+ const uint32_t n_ff = hparams.n_ff;
struct ggml_context * ctx = model->ctx;
model->train_samples = 0;
model->train_tokens = 0;
+ std::vector<char> tn_buf;
+ tn_buf.resize(GGML_MAX_NAME);
+ auto tn = [&tn_buf](const char * key) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
+ return tn_buf.data();
+ };
+ auto tni = [&tn_buf](const char * key, int bid) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+ std::string s = tn_buf.data();
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
+ return tn_buf.data();
+ };
+
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
- ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
- ggml_set_name(model->norm, "norm.weight");
- ggml_set_name(model->output, "output.weight");
+ ggml_set_name(model->tok_embeddings, tn(LLM_TENSOR_TOKEN_EMBD));
+ ggml_set_name(model->norm, tn(LLM_TENSOR_OUTPUT_NORM));
+ ggml_set_name(model->output, tn(LLM_TENSOR_OUTPUT));
model->layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
- std::string layers_i = "layers." + std::to_string(i);
-
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
- ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str());
+ ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
- ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str());
- ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str());
- ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str());
- ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str());
+ ggml_set_name(layer.wq, tni(LLM_TENSOR_ATTN_Q, i));
+ ggml_set_name(layer.wk, tni(LLM_TENSOR_ATTN_K, i));
+ ggml_set_name(layer.wv, tni(LLM_TENSOR_ATTN_V, i));
+ ggml_set_name(layer.wo, tni(LLM_TENSOR_ATTN_OUT, i));
- ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str());
+ ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
- ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str());
- ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str());
- ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str());
+ ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
+ ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
+ ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
}
}
}
}
-bool init_kv_cache(struct my_llama_kv_cache* cache, struct my_llama_model * model, int n_batch) {
- const auto & hparams = model->hparams;
-
- const uint32_t n_ctx = hparams.n_ctx;
- const uint32_t n_embd = hparams.n_embd;
- const uint32_t n_layer = hparams.n_layer;
-
- const int64_t n_mem = n_layer*n_ctx*n_batch;
- const int64_t n_elements = n_embd*n_mem;
-
- // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
-
- // struct ggml_init_params params;
- // params.mem_size = cache.buf.size;
- // params.mem_buffer = cache.buf.addr;
- // params.no_alloc = false;
- if (!cache->ctx) {
- struct ggml_init_params params;
- params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024;
- params.mem_buffer = NULL;
- params.no_alloc = false;
-
- cache->ctx = ggml_init(params);
-
- if (!cache->ctx) {
- fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
- return false;
- }
- }
-
- cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
- cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
-
- return true;
-}
-
-struct ggml_tensor * forward(
- struct my_llama_model * model,
- struct my_llama_kv_cache * cache,
- struct ggml_context * ctx0,
- struct ggml_cgraph * gf,
- struct ggml_tensor * tokens_input,
- const int n_tokens,
- const int n_past) {
-
- const int N = n_tokens;
-
- struct my_llama_kv_cache& kv_self = *cache;
- const auto & hparams = model->hparams;
- const int n_ctx = hparams.n_ctx;
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_head = hparams.n_head;
- const int n_rot = hparams.n_rot;
-
- struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
-
- struct ggml_tensor * kc = kv_self.k;
- struct ggml_tensor * vc = kv_self.v;
-
- // inpL shape [n_embd,N,1,1]
- struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
-
- struct ggml_tensor * cur;
-
- // lctx.use_buf(ctx0, 0);
-
- // norm
- {
- // cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
-
- // cur = attention_norm*cur
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
- cur);
- }
-
- // self-attention
- {
- // compute Q and K and RoPE them
- // wq shape [n_embd, n_embd, 1, 1]
- // wk shape [n_embd, n_embd, 1, 1]
- // Qcur shape [n_embd/n_head, n_head, N, 1]
- // Kcur shape [n_embd/n_head, n_head, N, 1]
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
-
- // store key and value to memory
- {
- // compute the transposed [N, n_embd] V matrix
- // wv shape [n_embd, n_embd, 1, 1]
- // Vcur shape [n_embd, N, 1, 1]
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N)));
-
- // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
- // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
- // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0]
- // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
-
- /* {
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
- ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
-
- // important: storing RoPE-ed version of K in the KV cache!
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
- } //*/
-
- kc = ggml_set_1d_inplace(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
- vc = ggml_set_2d_inplace(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
- }
-
- // Qcur shape [n_embd/n_head, n_head, N, 1]
- // Q shape [n_embd/n_head, N, n_head, 1]
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- Qcur,
- 0, 2, 1, 3);
-
- // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
- // K shape [n_embd/n_head, n_past + N, n_head, 1]
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd),
- n_embd/n_head, n_head, n_past + N),
- 0, 2, 1, 3);
-
- // K * Q
- // KQ shape [n_past + N, N, n_head, 1]
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
- // KQ_scaled = KQ / sqrt(n_embd/n_head)
- // KQ_scaled shape [n_past + N, N, n_head, 1]
- struct ggml_tensor * KQ_scaled =
- ggml_scale(ctx0,
- KQ,
- ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
-
- // KQ_masked = mask_past(KQ_scaled)
- // KQ_masked shape [n_past + N, N, n_head, 1]
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
-
- // KQ = soft_max(KQ_masked)
- // KQ_soft_max shape [n_past + N, N, n_head, 1]
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-
- // split cached V into n_head heads
- //// V shape [n_past + N, n_embd/n_head, n_head, 1]
- // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1]
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, vc,
- n_past + N, n_embd/n_head, n_head,
- n_ctx*ggml_element_size(vc),
- n_ctx*ggml_element_size(vc)*n_embd/n_head,
- il*n_ctx*ggml_element_size(vc)*n_embd);
-
- // KQV shape [n_embd/n_head, N, n_head, 1]
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-
- // KQV_merged = KQV.permute(0, 2, 1, 3)
- // KQV_merged shape [n_embd/n_head, n_head, N, 1]
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- // KQV_merged shape
-
- // cur = KQV_merged.contiguous().view(n_embd, N)
- // cur shape [n_embd,N,1,1]
- cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N);
- // cur = ggml_cpy(ctx0,
- // KQV_merged,
- // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
-
- // projection (no bias)
- // cur shape [n_embd,N,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].wo,
- cur);
- }
-
- // lctx.use_buf(ctx0, 1);
-
- // inpFF shape [n_embd,N,1,1]
- struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-
- // feed-forward network
- {
- // norm
- {
- // cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
-
- // cur = ffn_norm*cur
- // cur shape [n_embd,N,1,1]
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
- cur);
- }
-
- // tmp shape [n_ff,N,1,1]
- struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
- model->layers[il].w3,
- cur);
-
- // cur shape [n_ff,N,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w1,
- cur);
-
- // SILU activation
- // cur shape [n_ff,N,1,1]
- cur = ggml_silu(ctx0, cur);
-
- // cur shape [n_ff,N,1,1]
- cur = ggml_mul(ctx0, cur, tmp);
-
- // cur shape [n_embd,N,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w2,
- cur);
- }
-
- // cur shape [n_embd,N,1,1]
- cur = ggml_add(ctx0, cur, inpFF);
-
- // input for next layer
- // inpL shape [n_embd,N,1,1]
- inpL = cur;
- }
-
- // norm
- {
-
- // inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
-
- // inpL = norm*inpL
- // inpL shape [n_embd,N,1,1]
- inpL = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->norm, inpL),
- inpL);
-
- //embeddings = inpL;
- }
-
- // lm_head
- // inpL shape [n_vocab,N,1,1]
- inpL = ggml_mul_mat(ctx0, model->output, inpL);
-
- // run the computation
- ggml_build_forward_expand(gf, inpL);
-
- return inpL;
-}
-
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
GGML_ASSERT(tensor->n_dims == 1);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[3] == ne3);
}
-struct ggml_tensor * forward_batch(
- struct my_llama_model * model,
- struct my_llama_kv_cache * cache,
- struct ggml_context * ctx0,
- struct ggml_cgraph * gf,
- struct ggml_tensor * tokens_input,
- const int n_tokens,
- const int n_past,
- const int n_batch) {
-
- const int N = n_tokens;
-
- struct my_llama_kv_cache& kv_self = *cache;
- const auto & hparams = model->hparams;
- const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_head = hparams.n_head;
- const int n_rot = hparams.n_rot;
- const int n_ff = get_n_ff(&hparams);
-
- struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
- memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
-
- struct ggml_tensor * kc = kv_self.k;
- struct ggml_tensor * vc = kv_self.v;
-
- // inpL shape [n_embd,N*n_batch,1]
- struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
- assert_shape_2d(inpL, n_embd, N*n_batch);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
-
- struct ggml_tensor * cur;
-
- // lctx.use_buf(ctx0, 0);
-
- // norm
- {
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = attention_norm*cur
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // self-attention
- {
- // compute Q and K and RoPE them
- // wq shape [n_embd, n_embd, 1, 1]
- // wk shape [n_embd, n_embd, 1, 1]
- // Qcur shape [n_embd/n_head, n_head, N, n_batch]
- // Kcur shape [n_embd/n_head, n_head, N, n_batch]
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
- assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
-
- // store key and value to memory
- {
- // compute the transposed [N, n_embd] V matrix
- // wv shape [n_embd, n_embd, 1, 1]
- // Vcur shape [N, n_embd, n_batch, 1]
- struct ggml_tensor * Vcur = ggml_cont(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_mul_mat(ctx0,
- model->layers[il].wv,
- cur),
- n_embd, N, n_batch),
- 1, 0, 2, 3));
- assert_shape_3d(Vcur, N, n_embd, n_batch);
-
- // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
- // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
- // k shape [n_embd * N, n_batch] == kv_self.k[:,n_past:n_past+N,:,il]
- // v shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il]
-
- /* {
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
- ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
-
- // important: storing RoPE-ed version of K in the KV cache!
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
- } //*/
-
- kc = ggml_set_2d_inplace(ctx0, kc,
- ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch),
- ggml_element_size(kc)*n_embd*n_ctx,
- (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past));
- vc = ggml_set_2d_inplace(ctx0, vc,
- ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch),
- ggml_element_size(vc)*n_ctx*n_embd,
- ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx));
-
- assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer);
- assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer);
- }
-
- // Qcur shape [n_embd/n_head, n_head, N, n_batch]
- // Q shape [n_embd/n_head, N, n_head, n_batch]
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- Qcur,
- 0, 2, 1, 3);
- assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
-
- // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
- // K shape [n_embd/n_head, n_past + N, n_head, n_batch]
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_reshape_4d(ctx0,
- ggml_view_3d(ctx0,
- kc,
- n_embd,
- (n_past + N),
- n_batch,
- n_embd*ggml_element_size(kc),
- n_ctx*n_embd*ggml_element_size(kc),
- il*n_batch*n_ctx*n_embd*ggml_element_size(kc)),
- n_embd/n_head, n_head, n_past + N, n_batch),
- 0, 2, 1, 3);
- assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch);
-
- // K * Q
- // KQ shape [n_past + N, N, n_head, n_batch]
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
- assert_shape_4d(KQ, n_past + N, N, n_head, n_batch);
-
- // KQ_scaled = KQ / sqrt(n_embd/n_head)
- // KQ_scaled shape [n_past + N, N, n_head, n_batch]
- struct ggml_tensor * KQ_scaled =
- ggml_scale_inplace(ctx0,
- KQ,
- ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
- assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
-
- // KQ_masked = mask_past(KQ_scaled)
- // KQ_masked shape [n_past + N, N, n_head, n_batch]
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
- assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch);
-
- // KQ = soft_max(KQ_masked)
- // KQ_soft_max shape [n_past + N, N, n_head, n_batch]
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
- assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch);
-
- // split cached V into n_head heads
- // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
- // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il]
- struct ggml_tensor * V =
- ggml_view_4d(ctx0, vc,
- n_past + N, n_embd/n_head, n_head, n_batch,
- ggml_element_size(vc)*n_ctx,
- ggml_element_size(vc)*n_ctx*n_embd/n_head,
- ggml_element_size(vc)*n_ctx*n_embd,
- il*n_batch*n_ctx*n_embd*ggml_element_size(vc));
- assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch);
-
- // KQV shape [n_embd/n_head, N, n_head, n_batch]
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
- assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
-
- // KQV_merged = KQV.permute(0, 2, 1, 3)
- // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
- // KQV_merged shape
-
- // cur = KQV_merged.contiguous().view(n_embd, N)
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
- assert_shape_2d(cur, n_embd, N*n_batch);
- // cur = ggml_cpy(ctx0,
- // KQV_merged,
- // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
-
- // projection (no bias)
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].wo,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // lctx.use_buf(ctx0, 1);
+static size_t hash(void * p) {
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
- // inpFF shape [n_embd,N*n_batch,1,1]
- struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
- assert_shape_2d(inpFF, n_embd, N*n_batch);
+static size_t hash_find(void * hash_table[], void * p) {
+ size_t h = hash(p);
- // feed-forward network
- {
- // norm
- {
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = ffn_norm*cur
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // tmp shape [n_ff,N*n_batch,1,1]
- struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
- model->layers[il].w3,
- cur);
- assert_shape_2d(tmp, n_ff, N*n_batch);
-
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w1,
- cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // SILU activation
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_silu(ctx0, cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_mul(ctx0, cur, tmp);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w2,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
+ // linear probing
+ size_t i = h;
+ while (hash_table[i] != NULL && hash_table[i] != p) {
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+ if (i == h) {
+ // visited all hash table entries -> not found
+ return GGML_GRAPH_HASHTABLE_SIZE;
}
-
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_add_inplace(ctx0, cur, inpFF);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // input for next layer
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = cur;
- assert_shape_2d(inpL, n_embd, N*n_batch);
}
+ return i;
+}
- // norm
- {
-
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(inpL, n_embd, N*n_batch);
-
- // inpL = norm*inpL
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->norm, inpL),
- inpL);
-
- assert_shape_2d(inpL, n_embd, N*n_batch);
+static bool hash_insert(void * hash_table[], void * p) {
+ //size_t h = hash(p);
+ size_t i = hash_find(hash_table, p);
- //embeddings = inpL;
- }
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- // lm_head
- // inpL shape [n_vocab,N*n_batch,1,1]
- inpL = ggml_mul_mat(ctx0, model->output, inpL);
- assert_shape_2d(inpL, n_vocab, N*n_batch);
-
- {
- // inpL shape [n_vocab,N,n_batch,1]
- inpL = ggml_reshape_3d(ctx0,
- inpL,
- n_vocab, N, n_batch);
- assert_shape_3d(inpL, n_vocab, N, n_batch);
+ if (hash_table[i] == p) {
+ return true;
}
- // run the computation
- ggml_build_forward_expand(gf, inpL);
-
- return inpL;
+ // insert
+ GGML_ASSERT(hash_table[i] == NULL);
+ hash_table[i] = p;
+ return false;
}
-struct ggml_tensor * forward_batch_wo_cache(
- struct my_llama_model * model,
- struct ggml_context * ctx0,
- struct ggml_cgraph * gf,
- struct ggml_tensor * tokens_input,
- const int n_tokens,
- const int n_batch) {
-
- const int n_past = 0;
- const int N = n_tokens;
-
- const auto & hparams = model->hparams;
- //const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_head = hparams.n_head;
- const int n_rot = hparams.n_rot;
- const int n_ff = get_n_ff(&hparams);
-
- struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
- memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
-
- // inpL shape [n_embd,N*n_batch,1]
- struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
- assert_shape_2d(inpL, n_embd, N*n_batch);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
-
- struct ggml_tensor * cur;
-
- // lctx.use_buf(ctx0, 0);
-
- // norm
- {
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = attention_norm*cur
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // self-attention
- {
- // compute Q and K and RoPE them
- // wq shape [n_embd, n_embd, 1, 1]
- // wk shape [n_embd, n_embd, 1, 1]
- // Qcur shape [n_embd/n_head, n_head, N, n_batch]
- // Kcur shape [n_embd/n_head, n_head, N, n_batch]
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
- assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
-
- // Vcur shape [N, n_batch, n_embd/n_head, n_head]
- struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
- assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);
-
- // Qcur shape [n_embd/n_head, n_head, N, n_batch]
- // Q shape [n_embd/n_head, N, n_head, n_batch]
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- Qcur,
- 0, 2, 1, 3);
- assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
-
- // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
- // K shape [n_embd/n_head, N, n_head, n_batch]
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- Kcur,
- 0, 2, 1, 3);
- assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);
-
- // K * Q
- // KQ shape [N, N, n_head, n_batch]
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
- assert_shape_4d(KQ, N, N, n_head, n_batch);
-
- // KQ_scaled = KQ / sqrt(n_embd/n_head)
- // KQ_scaled shape [N, N, n_head, n_batch]
- struct ggml_tensor * KQ_scaled =
- ggml_scale_inplace(ctx0,
- KQ,
- ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
- assert_shape_4d(KQ_scaled, N, N, n_head, n_batch);
-
- // KQ_masked = mask_past(KQ_scaled)
- // KQ_masked shape [N, N, n_head, n_batch]
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
- assert_shape_4d(KQ_masked, N, N, n_head, n_batch);
-
- // KQ = soft_max(KQ_masked)
- // KQ_soft_max shape [N, N, n_head, n_batch]
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
- assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch);
-
- // Vcur shape [N, n_batch, n_embd/n_head, n_head]
- // V shape [N, n_embd/n_head, n_head, n_batch]
- struct ggml_tensor * V =
- ggml_permute(ctx0,
- Vcur,
- 0, 3, 1, 2);
- assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);
-
- // KQV shape [n_embd/n_head, N, n_head, n_batch]
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
- assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
-
- // KQV_merged = KQV.permute(0, 2, 1, 3)
- // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
- // KQV_merged shape
-
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // projection (no bias)
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].wo,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // lctx.use_buf(ctx0, 1);
-
- // inpFF shape [n_embd,N*n_batch,1,1]
- struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
- assert_shape_2d(inpFF, n_embd, N*n_batch);
-
- // feed-forward network
- {
- // norm
- {
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = ffn_norm*cur
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // tmp shape [n_ff,N*n_batch,1,1]
- struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
- model->layers[il].w3,
- cur);
- assert_shape_2d(tmp, n_ff, N*n_batch);
-
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w1,
- cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // SILU activation
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_silu(ctx0, cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // cur shape [n_ff,N*n_batch,1,1]
- cur = ggml_mul(ctx0, cur, tmp);
- assert_shape_2d(cur, n_ff, N*n_batch);
-
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w2,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
+static bool hash_contains(void * hash_table[], void * p) {
+ size_t i = hash_find(hash_table, p);
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
- // cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_add_inplace(ctx0, cur, inpFF);
- assert_shape_2d(cur, n_embd, N*n_batch);
+struct hash_map {
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
- // input for next layer
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = cur;
- assert_shape_2d(inpL, n_embd, N*n_batch);
+struct hash_map * new_hash_map() {
+ struct hash_map * result = new struct hash_map;
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+ result->keys[i] = NULL;
+ result->vals[i] = NULL;
}
+ return result;
+};
- // norm
- {
-
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(inpL, n_embd, N*n_batch);
-
- // inpL = norm*inpL
- // inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->norm, inpL),
- inpL);
-
- assert_shape_2d(inpL, n_embd, N*n_batch);
-
- //embeddings = inpL;
- }
+void free_hash_map(struct hash_map * map) {
+ delete map;
+}
- // lm_head
- // inpL shape [n_vocab,N*n_batch,1,1]
- inpL = ggml_mul_mat(ctx0, model->output, inpL);
- assert_shape_2d(inpL, n_vocab, N*n_batch);
+static bool ggml_is_view(struct ggml_tensor * t) {
+ return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
+ t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
+}
- {
- // inpL shape [n_vocab,N,n_batch,1]
- inpL = ggml_reshape_3d(ctx0,
- inpL,
- n_vocab, N, n_batch);
- assert_shape_3d(inpL, n_vocab, N, n_batch);
+static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
+ switch (t->op) {
+ case GGML_OP_PERMUTE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_VIEW:
+ return t->src[0];
+ case GGML_OP_CPY:
+ return t->src[1];
+ default:
+ return NULL;
}
-
- // run the computation
- ggml_build_forward_expand(gf, inpL);
-
- return inpL;
}
-struct ggml_tensor * forward_batch_wo_cache_flash_attn(
- struct my_llama_model * model,
- struct ggml_context * ctx0,
- struct ggml_cgraph * gf,
- struct ggml_tensor * tokens_input,
- const int n_tokens,
- const int n_batch) {
-
- const int n_past = 0;
- const int N = n_tokens;
-
- const auto & hparams = model->hparams;
- //const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_head = hparams.n_head;
- const int n_rot = hparams.n_rot;
- const int n_ff = get_n_ff(&hparams);
-
- struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
- memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
-
- struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
- assert_shape_2d(inpL, n_embd, N*n_batch);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
-
- struct ggml_tensor * cur;
-
- // norm
- {
- cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = attention_norm*cur
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- // self-attention
- {
- // compute Q and K and RoPE them
- // wq shape [n_embd, n_embd, 1, 1]
- // wk shape [n_embd, n_embd, 1, 1]
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
- assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
- assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
-
- struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
- assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);
-
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- Qcur,
- 0, 2, 1, 3);
- assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
-
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- Kcur,
- 0, 2, 1, 3);
- assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);
-
- struct ggml_tensor * V =
- ggml_permute(ctx0,
- Vcur,
- 0, 3, 1, 2);
- assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);
-
- bool masked = true;
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked);
- assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
-
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
- cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // projection (no bias)
- cur = ggml_mul_mat(ctx0,
- model->layers[il].wo,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
-
- struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
- assert_shape_2d(inpFF, n_embd, N*n_batch);
-
- // feed-forward network
- {
- // norm
- {
- cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // cur = ffn_norm*cur
- cur = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
- }
+static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
+ struct ggml_tensor * parent = t;
+ do {
+ parent = get_view_parent(parent);
+ } while (ggml_is_view(parent));
+ return parent;
+}
- struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
- model->layers[il].w3,
- cur);
- assert_shape_2d(tmp, n_ff, N*n_batch);
+struct ggml_tensor * ggml_recompute_graph_node(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ struct hash_map * replacements,
+ struct ggml_tensor * node) {
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w1,
- cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
+ if (node == NULL) {
+ return NULL;
+ }
- // SILU activation
- cur = ggml_silu(ctx0, cur);
- assert_shape_2d(cur, n_ff, N*n_batch);
+ if (node->is_param) {
+ return node;
+ }
- cur = ggml_mul(ctx0, cur, tmp);
- assert_shape_2d(cur, n_ff, N*n_batch);
+ if (!hash_contains(graph->visited_hash_table, node)) {
+ return node;
+ }
- cur = ggml_mul_mat(ctx0,
- model->layers[il].w2,
- cur);
- assert_shape_2d(cur, n_embd, N*n_batch);
+ int count_children = 0;
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ if (node->src[k]) {
+ ++count_children;
}
-
- cur = ggml_add_inplace(ctx0, cur, inpFF);
- assert_shape_2d(cur, n_embd, N*n_batch);
-
- // input for next layer
- inpL = cur;
- assert_shape_2d(inpL, n_embd, N*n_batch);
}
- // norm
- {
-
- inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
- assert_shape_2d(inpL, n_embd, N*n_batch);
-
- // inpL = norm*inpL
- inpL = ggml_mul(ctx0,
- ggml_repeat(ctx0, model->norm, inpL),
- inpL);
-
- assert_shape_2d(inpL, n_embd, N*n_batch);
+ if (count_children == 0) {
+ return node;
}
- // lm_head
- inpL = ggml_mul_mat(ctx0, model->output, inpL);
- assert_shape_2d(inpL, n_vocab, N*n_batch);
-
- {
- inpL = ggml_reshape_3d(ctx0,
- inpL,
- n_vocab, N, n_batch);
- assert_shape_3d(inpL, n_vocab, N, n_batch);
+ size_t i = hash_find(replacements->keys, node);
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ if (replacements->keys[i] == node) {
+ return (struct ggml_tensor *) replacements->vals[i];
}
- // run the computation
- ggml_build_forward_expand(gf, inpL);
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
- return inpL;
-}
+ // insert clone into replacements
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+ replacements->keys[i] = node;
+ replacements->vals[i] = clone;
-// expand the graph nodes without creating leafs.
-struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
- // check if already visited
- for (int i = 0; i < g->n_nodes; i++) {
- if (g->nodes[i] == t) {
- return t;
- }
+ clone->op = node->op;
+ clone->grad = node->grad;
+ clone->is_param = node->is_param;
+ clone->extra = node->extra;
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+ clone->nb[k] = node->nb[k];
}
-
- for (int i = 0; i < g->n_leafs; i++) {
- if (g->leafs[i] == t) {
- return t;
- }
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
}
-
- for (int i = 0; i < GGML_MAX_SRC; ++i) {
- if (t->src[i]) {
- expand(g, t->src[i]);
- }
+ if (ggml_is_view(clone)) {
+ struct ggml_tensor * source = get_view_source(clone);
+ GGML_ASSERT(source != NULL);
+ clone->data = source->data;
}
- GGML_ASSERT(g->n_nodes < GGML_MAX_NODES);
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
- if (strlen(t->name) == 0) {
- snprintf(t->name, sizeof(t->name), "node_%d", g->n_nodes);
- }
-
- g->nodes[g->n_nodes] = t;
- g->grads[g->n_nodes] = t->grad;
- g->n_nodes++;
- return t;
-}
+ return clone;
+};
-void graph_set_leafs_grads(struct ggml_cgraph * g) {
- // moves leaf nodes to g->leafs.
- // i.e. g->n_nodes might change.
- int n_nodes = 0;
- for (int i = 0; i < g->n_nodes; ++i) {
- struct ggml_tensor * node = g->nodes[i];
- const bool is_leaf = node->op == GGML_OP_NONE && node->grad == NULL;
- if (is_leaf) {
- GGML_ASSERT(g->n_leafs < GGML_MAX_NODES);
-
- if (strlen(node->name) == 0) {
- snprintf(node->name, sizeof(node->name), "leaf_%d", g->n_leafs);
- }
+void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints) {
+ *gb_tmp = *gf;
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+ if (n_checkpoints <= 0) {
+ *gb = *gb_tmp;
+ return;
+ }
- g->leafs[g->n_leafs] = node;
- g->n_leafs++;
- } else {
- GGML_ASSERT(n_nodes < GGML_MAX_NODES);
+ struct hash_map * replacements = new_hash_map();
- if (strlen(node->name) == 0) {
- snprintf(node->name, sizeof(node->name), "node_%d", n_nodes);
- }
+ // insert checkpoints in replacements
+ for (int i = 0; i < n_checkpoints; ++i) {
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+ replacements->keys[k] = checkpoints[i];
+ replacements->vals[k] = checkpoints[i];
+ }
- g->nodes[n_nodes] = node;
- g->grads[n_nodes] = node->grad;
- n_nodes++;
+ *gb = *gf;
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+ // by recomputing them from checkpoints
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
+ struct ggml_tensor * node = gb_tmp->nodes[i];
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ // insert new tensors recomputing src, reusing already made replacements,
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
+ // recurse for input tensors,
+ // unless (i.e. terminating when) input tensors are checkpoints
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
}
+ // insert rewritten backward node with replacements made into resulting backward graph gb
+ ggml_build_forward_expand(gb, node);
}
- for (int i=n_nodes; i < g->n_nodes; ++i) {
- g->nodes[n_nodes] = NULL;
- g->grads[n_nodes] = NULL;
- }
- g->n_nodes = n_nodes;
+
+ free_hash_map(replacements);
}
-struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
+struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model,
- struct ggml_context * ctx0,
+ struct ggml_allocr * alloc,
+ struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
struct ggml_tensor * * logits,
struct ggml_tensor * tokens_input,
struct ggml_tensor * targets,
- void * compute_buf_0,
- void * compute_buf_1,
- size_t size_buf_0,
- size_t size_buf_1,
const int n_tokens,
- const int n_batch) {
-
- ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+ const int n_batch,
+ const bool enable_flash_attn,
+ const bool enable_checkpointing) {
+ ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
const int N = n_tokens;
-
- gf->n_nodes = 0;
- gf->n_leafs = 0;
- gf->perf_runs = 0;
- gf->perf_cycles = 0;
- gf->perf_time_us = 0;
-
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
- const int n_ff = get_n_ff(&hparams);
- const int rope_mode = 0;
-
- int last_buf = -1;
- size_t buf_offs[2] = { 0, 0 };
- size_t buf_size[2] = { size_buf_0,
- size_buf_1 };
- void * buf_data[2] = { compute_buf_0,
- compute_buf_1 };
- auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) {
- size_t last_offs = 0;
- last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
- if (last_buf >= 0) {
- buf_offs[last_buf] = last_offs;
- }
- if (buf >= 0) {
- size_t offs = buf_offs[buf];
- size_t size = buf_size[buf];
- void * data = buf_data[buf];
- ggml_set_scratch(ctx0, { offs, size, data, });
- }
- last_buf = buf;
- };
-
- bool track_max_mem = false;
- size_t buf_maxs[2] = { 0, 0 };
-
- auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
- if (buf < 0) return;
- if (track_max_mem) {
- size_t last_offs = 0;
- last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
- if (last_buf >= 0) {
- buf_offs[last_buf] = last_offs;
- buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
- }
- }
- buf_offs[buf] = 0;
- if (track_max_mem && last_buf >= 0) {
- size_t offs = buf_offs[last_buf];
- size_t size = buf_size[last_buf];
- void * data = buf_data[last_buf];
- ggml_set_scratch(ctx0, { offs, size, data, });
+ const int n_ff = hparams.n_ff;
+ const float f_norm_rms_eps = hparams.f_norm_rms_eps;
+ const float rope_freq_base = hparams.rope_freq_base;
+ const float rope_freq_scale = hparams.rope_freq_scale;
+
+ auto set_name = [](struct ggml_tensor * t, const char * n) {
+ ggml_set_name(t, n);
+ if (t->grad) {
+ ggml_format_name(t->grad, "%s->grad", n);
}
};
+ // rope has so much parameters that we make a custom function for it
+ auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+ (struct ggml_tensor * t) -> struct ggml_tensor * {
+ // not capturing these, to silcence warnings
+ const int n_past = 0;
+ const int rope_mode = 0;
- auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
- int64_t ne0 = n_embd/n_head;
- int64_t ne1 = N;
- int64_t ne2 = n_head;
- int64_t ne3 = n_batch;
- size_t nb0 = ggml_element_size(t);
- size_t nb1 = nb0*ne0;
- size_t nb2 = nb1*ne1;
- size_t nb3 = nb2*ne2;
- size_t offset = 0;
- return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
+ return ggml_rope_custom(ctx,
+ t, n_past, n_rot, rope_mode, n_ctx,
+ rope_freq_base, rope_freq_scale);
};
- auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
- int64_t ne0 = n_embd/n_head;
- int64_t ne1 = N;
- int64_t ne2 = n_head;
- int64_t ne3 = n_batch;
- size_t nb0 = ggml_element_size(t);
- size_t nb1 = nb0*ne0;
- size_t nb2 = nb1*ne1;
- size_t nb3 = nb2*ne2;
- size_t offset = nb3*ne3;
- return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
- };
+ set_name(tokens_input, "tokens_input");
+ set_name(targets, "targets");
- auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
- int64_t ne0 = N;
- int64_t ne1 = n_embd/n_head;
- int64_t ne2 = n_head;
- int64_t ne3 = n_batch;
- size_t nb0 = ggml_element_size(t);
- size_t nb1 = nb0*ne0;
- size_t nb2 = nb1*ne1;
- size_t nb3 = nb2*ne2;
- size_t offset = 2*nb3*ne3;
- return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
- };
+ GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+ struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
+ struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
- auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * {
- if (a == NULL) {
- return b;
- } else {
- return ggml_add_inplace(ctx0, a, b);
- }
- };
-
- use_buf(-1);
+ struct ggml_tensor * cur = t01;
- model->tok_embeddings->grad = NULL;
- model->norm->grad = NULL;
- model->output->grad = NULL;
+ std::vector<struct ggml_tensor *> checkpoints;
+ checkpoints.push_back(tokens_input);
+ checkpoints.push_back(targets);
+ checkpoints.push_back(t00);
+ checkpoints.push_back(t01);
- for (int il = 0; il < n_layer; ++il) {
- struct my_llama_layer & layer = model->layers[il];
- layer.attention_norm->grad = NULL;
- layer.wq->grad = NULL;
- layer.wk->grad = NULL;
- layer.wv->grad = NULL;
- layer.wo->grad = NULL;
- layer.ffn_norm->grad = NULL;
- layer.w1->grad = NULL;
- layer.w2->grad = NULL;
- layer.w3->grad = NULL;
+ struct ggml_tensor * kv_scale;
+ if (!enable_flash_attn) {
+ kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
- clr_buf(0);
- clr_buf(1);
-
- use_buf(-1);
-
- struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
- memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
-
- use_buf(-1);
-
- struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
-
- // need to remember these for the backward pass
- std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t03L; t03L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t04L; t04L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t05L; t05L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t06L; t06L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t07L; t07L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t08L; t08L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t09L; t09L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t10L; t10L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t11L; t11L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t12L; t12L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t13L; t13L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t14L; t14L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t15L; t15L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t16L; t16L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t17L; t17L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t18L; t18L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t19L; t19L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t20L; t20L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t21L; t21L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t22L; t22L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t23L; t23L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t24L; t24L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t25L; t25L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t26L; t26L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t27L; t27L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t28L; t28L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t29L; t29L.resize(n_layer, NULL);
- std::vector<struct ggml_tensor *> t30L; t30L.resize(n_layer, NULL);
-
- struct ggml_tensor * cur = t01;
-
for (int il = 0; il < n_layer; ++il) {
- clr_buf(0);
struct my_llama_layer & layer = model->layers[il];
- // tensors with values necessary for backward pass are in persistent buf(-1)
- // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
- use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
- use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
- use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
- use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
- use_buf(-1); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
- use_buf(-1); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
- use_buf(-1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
- use_buf( 0); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
- use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
- use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
- use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
- use_buf(-1); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
- use_buf(-1); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
- use_buf(-1); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
- use_buf( 0); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);
- t02L[il] = t02;
- t03L[il] = t03;
- t04L[il] = t04;
- t05L[il] = t05;
- t06L[il] = t06;
- t07L[il] = t07;
- t08L[il] = t08;
- t09L[il] = t09;
- t10L[il] = t10;
- t11L[il] = t11;
- t12L[il] = t12;
- t13L[il] = t13;
- t14L[il] = t14;
- t15L[il] = t15;
- t16L[il] = t16;
- t17L[il] = t17;
- t18L[il] = t18;
- t19L[il] = t19;
- t20L[il] = t20;
- t21L[il] = t21;
- t22L[il] = t22;
- t23L[il] = t23;
- t24L[il] = t24;
- t25L[il] = t25;
- t26L[il] = t26;
- t27L[il] = t27;
- t28L[il] = t28;
- t29L[il] = t29;
- t30L[il] = t30;
-
- cur = t30;
- }
- clr_buf(0);
- use_buf(0);
- struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
- struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
- struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
- use_buf(-1);
- struct ggml_tensor * t34 = expand(gf, ggml_mul_mat (ctx0, model->output, t33)); assert_shape_2d(t34, n_vocab, N*n_batch);
- struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch);
- struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets)); assert_shape_1d(t36, 1);
-
- {
- /*
- tok_embeddings | grad_tok_embeddings = ggml_get_rows_back(grad_t01, t00)
- L0_att_norm | grad_L0_att_norm = ggml_repeat_back(grad_t03L0, L0_att_norm.shape)
- L0_wq | grad_L0_wq = ggml_out_prod(t04L0, grad_t05L0)
- L0_wk | grad_L0_wk = ggml_out_prod(t04L0, grad_t08L0)
- L0_wv | grad_L0_wv = ggml_out_prod(t04L0, ggml_transpose(grad_t11L0))
- L0_wo | grad_L0_wo = ggml_out_prod(t19L0, grad_t20L0)
- L0_ffn_norm | grad_L0_ffn_norm = ggml_repeat_back(grad_t23L0, L0_ffn_norm.shape)
- L0_w1 | grad_L0_w1 = ggml_out_prod(t24L0, grad_t26L0)
- L0_w2 | grad_L0_w2 = ggml_out_prod(t28L0, grad_t29L0)
- L0_w3 | grad_L0_w3 = ggml_out_prod(t24L0, grad_t25L0)
- L1_att_norm | grad_L1_att_norm = ggml_repeat_back(grad_t03L1, L1_att_norm.shape)
- L1_wq | grad_L1_wq = ggml_out_prod(t04L1, grad_t05L1)
- L1_wk | grad_L1_wk = ggml_out_prod(t04L1, grad_t08L1)
- L1_wv | grad_L1_wv = ggml_out_prod(t04L1, ggml_transpose(grad_t11L1))
- L1_wo | grad_L1_wo = ggml_out_prod(t19L1, grad_t20L1)
- L1_ffn_norm | grad_L1_ffn_norm = ggml_repeat_back(grad_t23L1, L1_ffn_norm.shape)
- L1_w1 | grad_L1_w1 = ggml_out_prod(t24L1, grad_t26L1)
- L1_w2 | grad_L1_w2 = ggml_out_prod(t28L1, grad_t29L1)
- L1_w3 | grad_L1_w3 = ggml_out_prod(t24L1, grad_t25L1)
- norm | grad_norm = ggml_repeat_back(grad_t32, norm.shape)
- output | grad_output = ggml_out_prod(t33, grad_t34)
- |
- t01 = ggml_get_rows(tok_embeddings, t00) | grad_t01 = grad_t21L0 + ggml_rms_norm_back(t01, grad_t02L0)
- for layer: |
- t02L0*= ggml_rms_norm (t01) | grad_t02L0 = ggml_mul(grad_t04L0, t03L0)
- t03L0 = ggml_repeat (L0_att_norm, t02L0_shape) | grad_t03L0 = ggml_mul(grad_t04L0, t02L0)
- t04L0*= ggml_mul (t02L0, t03L0) | grad_t04L0 = ggml_out_prod(L0_wv, grad_t11L0) + ggml_out_prod(L0_wk, ggml_transpose(grad_t08L0)) + ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0))
- t05L0 = ggml_mul_mat (L0_wq, t04L0) | grad_t05L0 = ggml_reshape(grad_t06L0, t05L0_shape)
- t06L0 = ggml_reshape_4d (t05L0, n_embd/n_head, n_head, N, n_batch) | grad_t06L0 = ggml_rope_back(grad_t07L0)
- t07L0 = ggml_rope_inplace (t06L0) | grad_t07L0 = ggml_permute_back(grad_t13L0, 0, 2, 1, 3) = ggml_permute(grad_t13L0, 0, 2, 1, 3)
- t08L0 = ggml_mul_mat (L0_wk, t04L0) | grad_t08L0 = ggml_reshape(grad_t09L0, t08L0_shape)
- t09L0 = ggml_reshape_4d (t08L0, n_embd/n_head, n_head, N, n_batch) | grad_t09L0 = ggml_rope_back(grad_t10L0)
- t10L0 = ggml_rope_inplace (t09L0) | grad_t10L0 = ggml_permute_back(grad_t14L0, 0, 2, 1, 3) = ggml_permute(grad_t14L0, 0, 2, 1, 3)
- t11L0 = ggml_mul_mat (t04L0, L0_wv) | grad_t11L0 = ggml_reshape(grad_t12L0, t11L0_shape)
- t12L0 = ggml_reshape_4d (t11L0, N, n_batch, n_embd/n_head, n_head) | grad_t12L0 = ggml_permute_back(grad_t15L0, 0, 3, 1, 2) = ggml_permute(grad_t15L0, 0, 2, 3, 1)
- t13L0*= ggml_permute (t07L0, 0, 2, 1, 3) | grad_t13L0 = view__q(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
- t14L0*= ggml_permute (t10L0, 0, 2, 1, 3) | grad_t14L0 = view__k(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
- t15L0*= ggml_permute (t12L0, 0, 3, 1, 2) | grad_t15L0 = view__v(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
- t16L0 = ggml_flash_attn (t13L0, t14L0, t15L0) | grad_t16L0 = ggml_permute_back(grad_t17L0, 0, 2, 1, 3) = ggml_permute(grad_t17L0, 0, 2, 1, 3)
- t17L0 = ggml_permute (t16L0, 0, 2, 1, 3) | grad_t17L0 = grad_t18L0
- t18L0 = ggml_cont (t17L0) | grad_t18L0 = ggml_reshape(grad_t19L0, t18L0_shape)
- t19L0*= ggml_reshape_2d (t18L0, n_embd, N*n_batch) | grad_t19L0 = ggml_out_prod(L0_wo, ggml_transpose(grad_t20L0))
- t20L0 = ggml_mul_mat (L0_wo, t19L0) | grad_t20L0 = grad_t21L0
- t21L0*= ggml_add (t20L0, t01) | grad_t21L0 = grad_t30L0 + ggml_rms_norm_back(t21L0, grad_t22L0)
- t22L0*= ggml_rms_norm (t21L0) | grad_t22L0 = ggml_mul(grad_t24L0, t23L0)
- t23L0 = ggml_repeat (L0_ffn_norm, t22L0_shape) | grad_t23L0 = ggml_mul(grad_t24L0, t22L0)
- t24L0*= ggml_mul (t23L0, t22L0) | grad_t24L0 = ggml_out_prod(L0_w1, ggml_transpose(grad_t26L0)) + ggml_out_prod(L0_w3, ggml_transpose(grad_t25L0))
- t25L0*= ggml_mul_mat (L0_w3, t24L0) | grad_t25L0 = ggml_mul(grad_t28L0, t27L0)
- t26L0*= ggml_mul_mat (L0_w1, t24L0) | grad_t26L0 = ggml_silu_back(t26L0, grad_t27L0)
- t27L0*= ggml_silu (t26L0) | grad_t27L0 = ggml_mul(grad_t28L0, t25L0)
- t28L0*= ggml_mul (t27L0, t25L0) | grad_t28L0 = ggml_out_prod(L0_w2, ggml_transpose(grad_t29L0))
- t29L0 = ggml_mul_mat (L0_w2, t28L0) | grad_t29L0 = grad_t30L0
- t30L0*= ggml_add (t21L0, t29L0) | grad_t30L0 = ggml_rms_norm_back(t30L0, grad_t02L1) + grad_t21L1
- ^
- t02L1*= ggml_rms_norm (t30L0) | grad_t02L1 = ggml_mul(grad_t04L1, t03L1)
- t03L1 = ggml_repeat (L1_att_norm, t02L1_shape) | grad_t03L1 = ggml_mul(grad_t04L1, t02L1)
- t04L1*= ggml_mul (t02L1, t03L1) | grad_t04L1 = ggml_out_prod(L1_wv, grad_t11L1) + ggml_out_prod(L1_wk, ggml_transpose(grad_t08L1)) + ggml_out_prod(L1_wq, ggml_transpose(grad_t05L1))
- t05L1 = ggml_mul_mat (L1_wq, t04L1) | grad_t05L1 = ggml_reshape(grad_t06L1, t05L1_shape)
- t06L1 = ggml_reshape_4d (t05L1, n_embd/n_head, n_head, N, n_batch) | grad_t06L1 = ggml_rope_back(grad_t07L1)
- t07L1 = ggml_rope_inplace (t06L1) | grad_t07L1 = ggml_permute_back(grad_t13L1, 0, 2, 1, 3) = ggml_permute(grad_t13L1, 0, 2, 1, 3)
- t08L1 = ggml_mul_mat (L1_wk, t04L1) | grad_t08L1 = ggml_reshape(grad_t09L1, t08L1_shape)
- t09L1 = ggml_reshape_4d (t08L1, n_embd/n_head, n_head, N, n_batch) | grad_t09L1 = ggml_rope_back(grad_t10L1)
- t10L1 = ggml_rope_inplace (t09L1) | grad_t10L1 = ggml_permute_back(grad_t14L1, 0, 2, 1, 3) = ggml_permute(grad_t14L1, 0, 2, 1, 3)
- t11L1 = ggml_mul_mat (t04L1, L1_wv) | grad_t11L1 = ggml_reshape(grad_t12L1, t11L1_shape)
- t12L1 = ggml_reshape_4d (t11L1, N, n_batch, n_embd/n_head, n_head) | grad_t12L1 = ggml_permute_back(grad_t15L1, 0, 3, 1, 2) = ggml_permute(grad_t15L1, 0, 2, 3, 1)
- t13L1*= ggml_permute (t07L1, 0, 2, 1, 3) | grad_t13L1 = view__q(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
- t14L1*= ggml_permute (t10L1, 0, 2, 1, 3) | grad_t14L1 = view__k(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
- t15L1*= ggml_permute (t12L1, 0, 3, 1, 2) | grad_t15L1 = view__v(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
- t16L1 = ggml_flash_attn (t13L1, t14L1, t15L1) | grad_t16L1 = ggml_permute_back(grad_t17L1, 0, 2, 1, 3) = ggml_permute(grad_t17L1, 0, 2, 1, 3)
- t17L1 = ggml_permute (t16L1, 0, 2, 1, 3) | grad_t17L1 = grad_t18L1
- t18L1 = ggml_cont (t17L1) | grad_t18L1 = ggml_reshape(grad_t19L1, t18L1_shape)
- t19L1*= ggml_reshape_2d (t18L1, n_embd, N*n_batch) | grad_t19L1 = ggml_out_prod(L1_wo, ggml_transpose(grad_t20L1))
- t20L1 = ggml_mul_mat (L1_wo, t19L1) | grad_t20L1 = grad_t21L1
- t21L1*= ggml_add (t20L1, t30L0) | grad_t21L1 = grad_t30L1 + ggml_rms_norm_back(t21L1, grad_t22L1)
- t22L1*= ggml_rms_norm (t21L1) | grad_t22L1 = ggml_mul(grad_t24L1, t23L1)
- t23L1 = ggml_repeat (L1_ffn_norm, t22L1_shape) | grad_t23L1 = ggml_mul(grad_t24L1, t22L1)
- t24L1*= ggml_mul (t23L1, t22L1) | grad_t24L1 = ggml_out_prod(L1_w1, ggml_transpose(grad_t26L1)) + ggml_out_prod(L1_w3, ggml_transpose(grad_t25L1))
- t25L1*= ggml_mul_mat (L1_w3, t24L1) | grad_t25L1 = ggml_mul(grad_t28L1, t27L1)
- t26L1*= ggml_mul_mat (L1_w1, t24L1) | grad_t26L1 = ggml_silu_back(t26L1, grad_t27L1)
- t27L1*= ggml_silu (t26L1) | grad_t27L1 = ggml_mul(grad_t28L1, t25L1)
- t28L1*= ggml_mul (t27L1, t25L1) | grad_t28L1 = ggml_out_prod(L1_w2, ggml_transpose(grad_t29L1))
- t29L1 = ggml_mul_mat (L1_w2, t28L1) | grad_t29L1 = grad_t30L1
- t30L1*= ggml_add (t21L1, t29L1) | grad_t30L1 = ggml_rms_norm_back(t30L1, grad_t31)
- ^
- t31 = ggml_rms_norm (t30L1) | grad_t31 = ggml_mul(grad_t33, t32)
- t32 = ggml_repeat (norm, t31.shape) | grad_t32 = ggml_mul(grad_t33, t31)
- t33 = ggml_mul (t32, t31) | grad_t33 = ggml_out_prod(output, ggml_transpose(grad_t34))
- t34 = ggml_mul_mat (output, t33) | grad_t34 = ggml_reshape(grad_t35, t34.shape)
- t35 = ggml_reshape_3d (t34, n_vocab, N, n_batch) | grad_t35 = ggml_cross_entropy_loss_back(t35, targets, grad_t36)
- t36 = ggml_cross_entropy_loss(t35, targets) | grad_t36 = 1 (optimizer)
- tensors marked with * need to be stored until grad computation
- tensors during grad computation are all temporary
- */
- }
-
- *gb = *gf;
-
- // t36->grad gets set to one by optimizer, so we need the tensor.
- // initialize it with 1.0f to make sure.
- use_buf(-1);
- t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
-
- use_buf(0);
- t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
- t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch);
- t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch);
- t32->grad = expand(gb, ggml_mul (ctx0, t33->grad, t31)); assert_shape_2d(t32->grad, n_embd, N*n_batch);
-
- use_buf(-1);
-
- model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
- model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab);
-
- clr_buf(1);
- use_buf(1);
- t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);
+ struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
+ struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
+ struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
+ struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
+ struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
+ struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd);
+ struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
+ struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
+ struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
+ struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
+ struct ggml_tensor * t16;
+ if (enable_flash_attn) {
+ t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+ } else {
+ struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
+ t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+ }
+ struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
+ struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
+ struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
+ struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
+ struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
+ struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
+ struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
+ struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
+ struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
+ struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
+ struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
+ struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
+ struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
+ cur = t30;
+ checkpoints.push_back(cur);
+ }
+ struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
+ struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
+ struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
+ struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
+ struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
+ struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
+
+ checkpoints.push_back(t31);
+ checkpoints.push_back(t32);
+ checkpoints.push_back(t33);
+ checkpoints.push_back(t34);
+ checkpoints.push_back(t35);
+ checkpoints.push_back(t36);
+
+ ggml_build_forward_expand(gf, t36);
+
+ if (enable_checkpointing) {
+ ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
+ } else {
+ *gb = *gf;
+ ggml_build_backward_expand(ctx, gf, gb, true);
+ }
+
+ if (alloc) {
+ // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
+ int n_leafs_before = gb->n_leafs;
+ int n_nodes_before = gb->n_nodes;
+ struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+ // output tensors
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+ // input gradient
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+ GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
+ ggml_allocr_alloc(alloc, t36->grad);
+ // gradient tensors (will be set to zero by ggml_graph_reset)
+ // pinning these produces large unnecessary memory overhead, which will be resolved by PR 2632
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (!gf->grads[i]) continue;
+ if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
+ ggml_allocr_alloc(alloc, gf->grads[i]);
+ }
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
+ }
+ // allocating checkpoints in one block to reduce memory fragmentation
+ // note: they will be freed in reverse order
+ for (int i = 0; i < (int) checkpoints.size(); ++i) {
+ if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
+ ggml_allocr_alloc(alloc, checkpoints[i]);
+ }
+ }
- struct ggml_tensor * back_layer_inp = t31;
- struct ggml_tensor * grad_layer_inp = NULL;
+ //int n_leafs_after = gb->n_leafs;
+ //int n_nodes_after = gb->n_nodes;
- for (int k = 0; k < n_layer; ++k) {
- int il = n_layer-1-k;
- struct my_llama_layer & layer = model->layers[il];
+ ggml_allocr_alloc_graph(alloc, gb);
- struct ggml_tensor * t02 = t02L[il];
- struct ggml_tensor * t03 = t03L[il];
- struct ggml_tensor * t04 = t04L[il];
- struct ggml_tensor * t05 = t05L[il];
- struct ggml_tensor * t06 = t06L[il];
- struct ggml_tensor * t07 = t07L[il];
- struct ggml_tensor * t08 = t08L[il];
- struct ggml_tensor * t09 = t09L[il];
- struct ggml_tensor * t10 = t10L[il];
- struct ggml_tensor * t11 = t11L[il];
- struct ggml_tensor * t12 = t12L[il];
- struct ggml_tensor * t13 = t13L[il];
- struct ggml_tensor * t14 = t14L[il];
- struct ggml_tensor * t15 = t15L[il];
- struct ggml_tensor * t16 = t16L[il];
- struct ggml_tensor * t17 = t17L[il];
- struct ggml_tensor * t18 = t18L[il];
- struct ggml_tensor * t19 = t19L[il];
- struct ggml_tensor * t20 = t20L[il];
- struct ggml_tensor * t21 = t21L[il];
- struct ggml_tensor * t22 = t22L[il];
- struct ggml_tensor * t23 = t23L[il];
- struct ggml_tensor * t24 = t24L[il];
- struct ggml_tensor * t25 = t25L[il];
- struct ggml_tensor * t26 = t26L[il];
- struct ggml_tensor * t27 = t27L[il];
- struct ggml_tensor * t28 = t28L[il];
- struct ggml_tensor * t29 = t29L[il];
- struct ggml_tensor * t30 = t30L[il];
-
- clr_buf(0);
- use_buf(0);
- t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
- if (grad_layer_inp) {
- t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+ // remove the additional nodes and leafs
+ for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
+ gb->leafs[i] = NULL;
}
- clr_buf(1);
- t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
- t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch);
- t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch);
- t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch);
- t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27)); assert_shape_2d(t25->grad, n_ff, N*n_batch);
- t24->grad = expand(gb, ggml_add_inplace(ctx0,
- ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)),
- ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch);
- t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
- t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
- use_buf(1);
- t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
- grad_layer_inp = t21;
- use_buf(0);
- t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
- t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch);
- t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
- t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch);
- t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch);
- struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch);
- t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch);
- t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch);
- t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch);
- t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
- t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
- t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
- t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
- t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
- t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
- t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
- t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
- t04->grad = expand(gb, ggml_add_inplace(ctx0,
- ggml_add_inplace(ctx0,
- ggml_out_prod(ctx0, layer.wv, t11->grad),
- ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
- ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch);
- t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
- use_buf(1);
- t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02))); assert_shape_2d(t02->grad, n_embd, N*n_batch);
- back_layer_inp = t02;
- // use_buf(0);
-
- use_buf(-1);
- layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
- layer.wq->grad = expand(gb, add_or_set(layer.wq->grad, ggml_out_prod(ctx0, t04, t05->grad))); assert_shape_2d(layer.wq->grad, n_embd, n_embd);
- layer.wk->grad = expand(gb, add_or_set(layer.wk->grad, ggml_out_prod(ctx0, t04, t08->grad))); assert_shape_2d(layer.wk->grad, n_embd, n_embd);
- layer.wv->grad = expand(gb, add_or_set(layer.wv->grad, ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd);
- layer.wo->grad = expand(gb, add_or_set(layer.wo->grad, ggml_out_prod(ctx0, t19, t20->grad))); assert_shape_2d(layer.wo->grad, n_embd, n_embd);
- layer.ffn_norm->grad = expand(gb, add_or_set(layer.ffn_norm->grad, ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm))); assert_shape_1d(layer.ffn_norm->grad, n_embd);
- layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff);
- layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd);
- layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff);
- // use_buf(0);
+ for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
+ gb->nodes[i] = NULL;
+ }
+ gb->n_leafs = n_leafs_before;
+ gb->n_nodes = n_nodes_before;
}
- clr_buf(0);
- use_buf(0);
- t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
- use_buf(-1);
- model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
- // clr_buf(1);
- // clr_buf(0);
*logits = t35;
-
- if (track_max_mem) {
- printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
- printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
- }
-
- // now that all grads are created, set the graph leafs and grads
- graph_set_leafs_grads(gf);
- graph_set_leafs_grads(gb);
-
return t36;
}
}
}
-
-void print_token(struct llama_context * ctx, llama_token token) {
- printf("%s", llama_token_to_piece(ctx, token).c_str());
-}
-
-void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
- for (int i=0; i<tokens->ne[0]; ++i) {
- int token = ggml_get_i32_1d(tokens, i);
- print_token(ctx, token);
- }
-}
-
-void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens) {
- for (int i1=0; i1<tokens->ne[1]; ++i1) {
- //int num_newline = 0;
- for (int i0=0; i0<tokens->ne[0]; ++i0) {
- int token = get_i32_2d(tokens, i0, i1);
- print_token(ctx, token);
- // bool isnl = (token == llama_token_nl());
- // if (isnl) {
- // ++num_newline;
- // }
- // if (isnl) {
- // if (num_newline < 2) {
- // print_token(ctx, token);
- // } else {
- // printf("\\n");
- // }
- // } else {
- // print_token(ctx, token);
- // }
- }
- printf("\n--\n");
- }
-}
-
void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
int n_tokens = tokens_input->ne[0];
int n_vocab = target_logits->ne[0];
ggml_set_f32(target_logits, -1.0f/n_vocab);
ggml_set_f32(target_probs, 0.0f);
+ // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);
- size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
+ size_t sample_idx = (example_id*n_batch + k) % n_train_samples;
+ size_t sample = train_samples[sample_idx];
+ // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
- // print_token(lctx, token);
set_f32_3d(target_logits, token, i-1, k, +1.0f);
set_f32_3d(target_probs, token, i-1, k, +1.0f);
if (i<n_tokens) {
set_i32_2d(tokens_input, i, k, token);
}
}
- // printf("\n=\n");
- // for (int i=0; i<n_tokens; ++i) {
- // int token = get_i32_2d(tokens_input, i, k);
- // print_token(lctx, token);
- // }
- // printf("\n-\n");
}
}
-void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs, int n_shift) {
- int n_tokens = tokens_input->ne[0];
- int n_vocab = target_logits->ne[0];
- for (int i=0; i<n_tokens-n_shift; ++i) {
- ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
- for (int k=0; k<n_vocab; ++k) {
- ggml_set_f32_1d(target_logits, i*n_vocab + k, ggml_get_f32_1d(target_logits, (i + n_shift)*n_vocab + k));
- ggml_set_f32_1d(target_probs, i*n_vocab + k, ggml_get_f32_1d(target_probs, (i + n_shift)*n_vocab + k));
- }
- }
-}
-
-struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * target) {
- return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, target, a)));
-}
-
-struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
- return ggml_cross_entropy_loss(ctx, a, probs);
-}
-
#ifdef __GNUC__
#ifdef __MINGW32__
__attribute__((format(gnu_printf, 1, 2)))
return std::string(buf.data(), size);
}
-struct llama_file {
- // use FILE * so we don't have to re-open the file to mmap
- FILE * fp;
- size_t size;
-
- llama_file(const char * fname, const char * mode) {
- fp = std::fopen(fname, mode);
- if (fp == NULL) {
- size = 0;
- } else {
- seek(0, SEEK_END);
- size = tell();
- seek(0, SEEK_SET);
- }
+int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
+ FILE * fp = std::fopen(filename, "rb");
+ if (fp == NULL) {
+ return 0;
}
- size_t tell() const {
#ifdef _WIN32
- __int64 ret = _ftelli64(fp);
+ GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0);
#else
- long ret = std::ftell(fp);
+ GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0);
#endif
- GGML_ASSERT(ret != -1); // this really shouldn't fail
- return (size_t) ret;
- }
- void seek(size_t offset, int whence) {
+ size_t size = 0;
#ifdef _WIN32
- int ret = _fseeki64(fp, (__int64) offset, whence);
+ __int64 ret = _ftelli64(fp);
+ size = ret;
#else
- int ret = std::fseek(fp, (long) offset, whence);
+ long ret = std::ftell(fp);
+ size = ret;
#endif
- GGML_ASSERT(ret == 0); // same
- }
-
- void read_raw(void * ptr, size_t size) {
- if (size == 0) {
- return;
- }
- errno = 0;
- std::size_t ret = std::fread(ptr, size, 1, fp);
- if (ferror(fp)) {
- throw std::runtime_error(format("read error: %s", strerror(errno)));
- }
- if (ret != 1) {
- throw std::runtime_error(std::string("unexpectedly reached end of file"));
- }
- }
-
- std::uint32_t read_u32() {
- std::uint32_t ret;
- read_raw(&ret, sizeof(ret));
- return ret;
- }
- std::string read_string(std::uint32_t len) {
- std::vector<char> chars(len);
- read_raw(chars.data(), len);
- return std::string(chars.data(), len);
- }
+#ifdef _WIN32
+ GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0);
+#else
+ GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0);
+#endif
- void write_raw(const void * ptr, size_t size) {
- if (size == 0) {
- return;
- }
- errno = 0;
- size_t ret = std::fwrite(ptr, size, 1, fp);
- if (ret != 1) {
- throw std::runtime_error(format("write error: %s", strerror(errno)));
- }
- }
+ std::vector<char> buf;
+ buf.resize(size+1);
+ out.resize(size+1);
- void write_u32(std::uint32_t val) {
- write_raw(&val, sizeof(val));
+ if (std::fread(buf.data(), size, 1, fp) != 1) {
+ throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
-
- ~llama_file() {
- if (fp) {
- std::fclose(fp);
- }
+ if (ferror(fp)) {
+ throw std::runtime_error(format("read error: %s", strerror(errno)));
}
-};
-
-int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
- struct llama_file f(filename, "rb");
-
- std::vector<char> buf;
- buf.resize(f.size+1);
- f.read_raw(buf.data(), f.size);
- buf[f.size] = '\0';
+ buf[size] = '\0';
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
if (n_tokens < 0) {
out.resize(-n_tokens);
- llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
+ n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
}
+ GGML_ASSERT(n_tokens >= 0);
+ out.resize(n_tokens);
bool verify = false;
if (verify) {
});
}
-struct my_llama_sampler_params {
- float temp = 0.0f; // <= 0.0 disabled
- int top_k = 20; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float repeat_penalty = 1.0f; // 1.0 = disabled
- float alpha_presence = 0.0f; // 0.0 = disabled
- float alpha_frequency = 0.0f; // 0.0 = disabled
- int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
- bool penalize_nl = true; // consider newlines as a repeatable token
-};
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+ const std::string skey(key); \
+ const int kid = gguf_find_key(ctx, skey.c_str()); \
+ if (kid >= 0) { \
+ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+ if (ktype != (type)) { \
+ throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
+ } \
+ (dst) = func(ctx, kid); \
+ } else if (req) { \
+ throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
+ } \
+}
-struct my_llama_sampler {
- struct llama_context * ctx = NULL;
- my_llama_sampler_params params;
- int n_vocab = 0;
- int n_ctx = 0;
+bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
+ GGML_ASSERT(a != NULL);
+ GGML_ASSERT(b != NULL);
+ GGML_ASSERT(a->type == b->type);
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
- float mirostat_mu;
+ return true;
+}
- std::vector<llama_token_data> candidates;
- llama_token_data_array candidates_p;
+void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
+ if (dst == NULL) {
+ return;
+ }
+ struct ggml_tensor * t = ggml_get_tensor(ctx, name);
+ GGML_ASSERT(are_same_layout(dst, t));
+ memcpy(dst->data, t->data, ggml_nbytes(t));
-};
+ if (strlen(ggml_get_name(dst)) == 0) {
+ ggml_set_name(dst, name);
+ }
+}
-void init_sampler(struct my_llama_sampler * sampler, struct llama_context * ctx) {
- sampler->ctx = ctx;
- sampler->n_vocab = llama_n_vocab(sampler->ctx);
- sampler->n_ctx = llama_n_ctx(sampler->ctx);
- sampler->mirostat_mu = 2.0f * sampler->params.mirostat_tau;
+void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
+ // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+
+ uint32_t file_version;
+ GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
+ GGML_ASSERT(file_version == 0);
+
+ GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
+ GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
+ GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
+
+ uint64_t nx;
+ GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
+ opt->nx = (size_t) nx;
+
+ // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
+
+ std::string opt_type;
+ GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
+ if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
+ opt->params.type = GGML_OPT_ADAM;
+
+ GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
+ GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
+ GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
+
+ GGML_ASSERT(opt->ctx != NULL);
+ ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+ read_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+ read_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+ read_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+ } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
+ opt->params.type = GGML_OPT_LBFGS;
+
+ GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
+ GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
+ GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
+ GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
+ GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
+ GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
+ GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
+
+ GGML_ASSERT(opt->ctx != NULL);
+ ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+ read_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+ read_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+ read_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+ read_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+ read_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+ read_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+ read_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+ read_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+ read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+ read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+ } else {
+ throw std::runtime_error("unknown optimizer type\n");
+ }
}
-llama_token sample(struct my_llama_sampler * sampler, float * logits, const llama_token * last_tokens, int n_last_tokens) {
- GGML_ASSERT(sampler->ctx != NULL);
+void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
+ gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
+ gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
- struct llama_context * ctx = sampler->ctx;
+ switch (opt->params.type) {
+ case GGML_OPT_ADAM:
+ {
+ gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
+
+ ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+ ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+ if (opt->adam.pf) {
+ ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+ }
- sampler->candidates.resize(sampler->n_vocab);
- for (llama_token token_id = 0; token_id < sampler->n_vocab; ++token_id) {
- sampler->candidates[token_id].id = token_id;
- sampler->candidates[token_id].logit = logits[token_id];
- sampler->candidates[token_id].p = 0.0;
+ gguf_add_tensor(fctx, opt->adam.m);
+ gguf_add_tensor(fctx, opt->adam.v);
+ if (opt->adam.pf) {
+ gguf_add_tensor(fctx, opt->adam.pf);
+ }
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
+
+ ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+ ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+ ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+ ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+ ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+ if (opt->lbfgs.pf) {
+ ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+ }
+ ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+ ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+ ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+ ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+
+ gguf_add_tensor(fctx, opt->lbfgs.x);
+ gguf_add_tensor(fctx, opt->lbfgs.xp);
+ gguf_add_tensor(fctx, opt->lbfgs.g);
+ gguf_add_tensor(fctx, opt->lbfgs.gp);
+ gguf_add_tensor(fctx, opt->lbfgs.d);
+ if (opt->lbfgs.pf) {
+ gguf_add_tensor(fctx, opt->lbfgs.pf);
+ }
+ gguf_add_tensor(fctx, opt->lbfgs.lmal);
+ gguf_add_tensor(fctx, opt->lbfgs.lmys);
+ gguf_add_tensor(fctx, opt->lbfgs.lms);
+ gguf_add_tensor(fctx, opt->lbfgs.lmy);
+ } break;
}
+}
+
+void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
+ // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+ std::string arch;
- llama_token_data_array * candidates_p = & sampler->candidates_p;
+ std::vector<char> keybuf;
+ keybuf.resize(512);
+ auto kv = [&arch, &keybuf](const char * key) -> const char * {
+ snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
+ return keybuf.data();
+ };
- candidates_p->data = sampler->candidates.data();
- candidates_p->size = sampler->candidates.size();
- candidates_p->sorted = false;
+ std::vector<char> tn_buf;
+ tn_buf.resize(GGML_MAX_NAME);
+ auto tn = [&tn_buf](const char * key) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
+ return tn_buf.data();
+ };
+ auto tni = [&tn_buf](const char * key, int bid) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+ std::string s = tn_buf.data();
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
+ return tn_buf.data();
+ };
- const auto params = sampler->params;
+ GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
+ GGML_ASSERT(arch == "llama");
- // Apply penalties
- const float nl_logit = logits[llama_token_nl(ctx)];
+ uint32_t ftype_u;
+ GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
+ GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
- const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
+ // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
+ GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
- llama_sample_repetition_penalty(
- ctx,
- candidates_p,
- last_tokens + n_last_tokens - n_last,
- n_last,
- params.repeat_penalty);
- llama_sample_frequency_and_presence_penalties(
- ctx,
- candidates_p,
- last_tokens + n_last_tokens - n_last,
- n_last,
- params.alpha_frequency,
- params.alpha_presence);
+ GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
+ GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+ GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+ GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
- if (!params.penalize_nl) {
- logits[llama_token_nl(ctx)] = nl_logit;
- }
+ model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
+ GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
- llama_token token = 0;
- if (params.temp <= 0) {
- // Greedy sampling
- token = llama_sample_token_greedy(ctx, candidates_p);
- } else {
- if (params.mirostat == 1) {
- int mirostat_m = 100;
- llama_sample_temperature(ctx, candidates_p, params.temp);
- token = llama_sample_token_mirostat(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, mirostat_m, &sampler->mirostat_mu);
- } else if (params.mirostat == 2) {
- llama_sample_temperature(ctx, candidates_p, params.temp);
- token = llama_sample_token_mirostat_v2(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, &sampler->mirostat_mu);
- } else {
- // Temperature sampling
- llama_sample_top_k (ctx, candidates_p, params.top_k, 1);
- llama_sample_tail_free (ctx, candidates_p, params.tfs_z, 1);
- llama_sample_typical (ctx, candidates_p, params.typical_p, 1);
-
- llama_sample_top_p (ctx, candidates_p, params.top_p, 1);
- llama_sample_temperature (ctx, candidates_p, params.temp);
- token = llama_sample_token(ctx, candidates_p);
- }
+ float rope_freq_scale = 1.0f;
+ GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+ GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+ GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+ if (rope_freq_scale != 1.0f) {
+ model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
}
- return token;
-}
-void set_logits_masked(struct ggml_tensor * logits, std::vector<bool>& mask, float value) {
- GGML_ASSERT(logits->ne[0] == (int64_t) mask.size());
- for (int i2 = 0; i2 < logits->ne[2]; ++i2) {
- for (int i1 = 0; i1 < logits->ne[1]; ++i1) {
- for (int i0 = 0; i0 < logits->ne[0]; ++i0) {
- if (!mask[i0]) continue;
- float * ptr = (float *) ((char *) logits->data + i2*logits->nb[2] + i1*logits->nb[1] + i0*logits->nb[0]);
- *ptr = value;
- }
- }
- }
-}
+ init_model(model);
-void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
- if (tensor == NULL) {
- file->write_u32(0);
- file->write_u32(0);
- file->write_u32(GGML_TYPE_F32);
- file->seek((0-file->tell()) & 31, SEEK_CUR);
- return;
+ read_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
+ read_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
+ read_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
+
+ for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+ auto & layer = model->layers[i];
+
+ read_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
+ read_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
+ read_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
+ read_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
+ read_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
+ read_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
+ read_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
+ read_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
+ read_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
}
- const char * name = ggml_get_name(tensor);
- uint32_t name_len = strlen(name);
- uint32_t nd = tensor->n_dims;
- uint32_t ne[4] = { (uint32_t)tensor->ne[0],
- (uint32_t)tensor->ne[1],
- (uint32_t)tensor->ne[2],
- (uint32_t)tensor->ne[3] };
- file->write_u32(nd);
- file->write_u32(name_len);
- file->write_u32(tensor->type);
- file->write_raw(ne, sizeof(ne[0]) * nd);
- file->write_raw(name, name_len);
- file->seek((0-file->tell()) & 31, SEEK_CUR);
- file->write_raw(tensor->data, ggml_nbytes(tensor));
}
-void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
- int32_t nd = file->read_u32();
- GGML_ASSERT(nd == tensor->n_dims);
+void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
+ const char * arch = "llama";
+ enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
- uint32_t name_len = file->read_u32();
- enum ggml_type type = (enum ggml_type) file->read_u32();
- GGML_ASSERT(type == tensor->type);
+ std::vector<char> keybuf;
+ keybuf.resize(512);
+ auto kv = [arch, &keybuf](const char * key) -> const char * {
+ snprintf(keybuf.data(), keybuf.size(), key, arch);
+ return keybuf.data();
+ };
- uint32_t ne[4];
- file->read_raw(ne, sizeof(ne[0]) * nd);
- for (int i=0; i<nd; ++i) {
- GGML_ASSERT(ne[i] == tensor->ne[i]);
- }
+ // set arch
+ gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
+ gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
- std::string name = file->read_string(name_len);
- GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0);
+ // set hparams
+ gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx );
+ gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd );
+ gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff );
+ gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head );
+ gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer );
+ gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot );
- file->seek((0-file->tell()) & 31, SEEK_CUR);
- file->read_raw(tensor->data, ggml_nbytes(tensor));
-}
+ gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps );
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base ); // TODO load in llama.cpp
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), 1.0f / model->hparams.rope_freq_scale );
-void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) {
- const uint32_t version = 0;
- GGML_ASSERT(opt->nx >= 0);
- GGML_ASSERT(opt->iter >= 0);
- file->write_u32(version);
- file->write_raw(&opt->params, sizeof(opt->params));
- file->write_raw(&opt->nx, sizeof(opt->nx));
- file->write_raw(&opt->iter, sizeof(opt->iter));
- file->write_u32((uint32_t) opt->just_initialized);
- switch (opt->params.type) {
- case GGML_OPT_ADAM:
- {
- GGML_ASSERT(opt->adam.x != NULL);
- write_tensor(file, opt->adam.x);
- write_tensor(file, opt->adam.g1);
- write_tensor(file, opt->adam.g2);
- write_tensor(file, opt->adam.m);
- write_tensor(file, opt->adam.v);
- write_tensor(file, opt->adam.mh);
- write_tensor(file, opt->adam.vh);
- write_tensor(file, opt->adam.pf);
- file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
- file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
- file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
- } break;
- case GGML_OPT_LBFGS:
- {
- GGML_ASSERT(opt->adam.x != NULL);
- write_tensor(file, opt->lbfgs.x);
- write_tensor(file, opt->lbfgs.xp);
- write_tensor(file, opt->lbfgs.g);
- write_tensor(file, opt->lbfgs.gp);
- write_tensor(file, opt->lbfgs.d);
- write_tensor(file, opt->lbfgs.pf);
- write_tensor(file, opt->lbfgs.lmal);
- write_tensor(file, opt->lbfgs.lmys);
- write_tensor(file, opt->lbfgs.lms);
- write_tensor(file, opt->lbfgs.lmy);
- file->write_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best));
- file->write_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step));
- file->write_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j));
- file->write_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k));
- file->write_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end));
- file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
- } break;
- }
-}
+ // set vocab by copying from vocab_model gguf file
+ {
+ struct gguf_init_params params = {
+ /*.no_alloc = */ false,
+ /*.ctx = */ NULL,
+ };
+ struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
+
+ const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST));
+ if (token_idx == -1) {
+ throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+ }
+ const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);
-void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
- uint32_t version = file->read_u32();
- GGML_ASSERT(version == 0);
+ const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES));
+ if (score_idx == -1) {
+ throw std::runtime_error("cannot find tokenizer scores in model file\n");
+ }
- file->read_raw(&opt->params, sizeof(opt->params));
- file->read_raw(&opt->nx, sizeof(opt->nx));
- ggml_opt_init(ctx, opt, opt->params, opt->nx);
+ const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx);
- file->read_raw(&opt->iter, sizeof(opt->iter));
- opt->just_initialized = (bool) file->read_u32();
+ const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE));
+ if (toktype_idx == -1) {
+ throw std::runtime_error("cannot find token type list in GGUF file\n");
+ }
- switch (opt->params.type) {
- case GGML_OPT_ADAM:
- {
- read_tensor(file, opt->adam.x);
- read_tensor(file, opt->adam.g1);
- read_tensor(file, opt->adam.g2);
- read_tensor(file, opt->adam.m);
- read_tensor(file, opt->adam.v);
- read_tensor(file, opt->adam.mh);
- read_tensor(file, opt->adam.vh);
- if (opt->adam.pf) { read_tensor(file, opt->adam.pf); }
- file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
- file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
- file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
- } break;
- case GGML_OPT_LBFGS:
- {
- GGML_ASSERT(opt->adam.x != NULL);
- read_tensor(file, opt->lbfgs.x);
- read_tensor(file, opt->lbfgs.xp);
- read_tensor(file, opt->lbfgs.g);
- read_tensor(file, opt->lbfgs.gp);
- read_tensor(file, opt->lbfgs.d);
- if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); }
- read_tensor(file, opt->lbfgs.lmal);
- read_tensor(file, opt->lbfgs.lmys);
- read_tensor(file, opt->lbfgs.lms);
- read_tensor(file, opt->lbfgs.lmy);
- file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best));
- file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step));
- file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j));
- file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k));
- file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end));
- file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
- } break;
- }
-}
+ const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx);
+
+ std::string tokenizer_name;
+ GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
+
+ gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL), tokenizer_name.c_str());
+ gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES), GGUF_TYPE_FLOAT32, scores, n_vocab);
+ gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE), GGUF_TYPE_INT32, toktypes, n_vocab);
+
+ int32_t special_bos_id = 1;
+ int32_t special_eos_id = 2;
+ int32_t special_unk_id = 0;
+ int32_t special_sep_id = -1;
+ int32_t special_pad_id = -1;
+ if (tokenizer_name == "llama") {
+ // default special tokens
+ special_bos_id = 1;
+ special_eos_id = 2;
+ special_unk_id = 0;
+ special_sep_id = -1;
+ special_pad_id = -1;
+ } else if (tokenizer_name == "gpt2") {
+ // read and copy bpe merges
+ const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES));
+ if (merges_keyidx == -1) {
+ throw std::runtime_error("cannot find tokenizer merges in model file\n");
+ }
-void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) {
- struct llama_file file(filename, "wb");
- if (file.fp == NULL) {
- return;
- }
+ const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);
+
+ std::vector<const char*> merges;
+ merges.resize(n_merges);
+ for (int i = 0; i < n_merges; i++) {
+ merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
+ }
+ gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES), merges.data(), n_merges);
+
+ // default special tokens
+ special_bos_id = 11;
+ special_eos_id = 11;
+ special_unk_id = -1;
+ special_sep_id = -1;
+ special_pad_id = -1;
+ } else {
+ fprintf(stderr, "%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
+ fprintf(stderr, "%s: using default tokenizer: 'llama'", __func__);
+ }
+
+ std::vector<const char*> tokens;
+ tokens.resize(n_vocab);
+ for (uint32_t i = 0; i < n_vocab; i++) {
+ tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
+ }
+ gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST), tokens.data(), n_vocab);
+
+ GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
+ GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
+ GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
+ GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
+ GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
- const uint32_t magic = 'ggcp';
- const uint32_t version = 0;
-
- file.write_u32(magic);
- file.write_u32(version);
- file.write_u32(model->train_its);
- file.write_u32(model->train_samples);
- file.write_u32(model->train_tokens);
- file.write_u32(model->hparams.n_vocab);
- file.write_u32(model->hparams.n_embd);
- file.write_u32(model->hparams.n_mult);
- file.write_u32(model->hparams.n_head);
- file.write_u32(model->hparams.n_layer);
- file.write_u32(model->hparams.n_rot);
-
- write_tensor(&file, model->tok_embeddings);
- write_tensor(&file, model->norm);
- write_tensor(&file, model->output);
+ gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID), special_bos_id);
+ gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID), special_eos_id);
+ gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID), special_unk_id);
+ gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID), special_sep_id);
+ gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID), special_pad_id);
+
+ gguf_free(vctx);
+ }
+ // add tensors
+ gguf_add_tensor(fctx, model->tok_embeddings);
+ gguf_add_tensor(fctx, model->norm);
+ gguf_add_tensor(fctx, model->output);
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];
- write_tensor(&file, layer.attention_norm);
- write_tensor(&file, layer.wq);
- write_tensor(&file, layer.wk);
- write_tensor(&file, layer.wv);
- write_tensor(&file, layer.wo);
- write_tensor(&file, layer.ffn_norm);
- write_tensor(&file, layer.w1);
- write_tensor(&file, layer.w2);
- write_tensor(&file, layer.w3);
+
+ gguf_add_tensor(fctx, layer.attention_norm);
+ gguf_add_tensor(fctx, layer.wq);
+ gguf_add_tensor(fctx, layer.wk);
+ gguf_add_tensor(fctx, layer.wv);
+ gguf_add_tensor(fctx, layer.wo);
+ gguf_add_tensor(fctx, layer.ffn_norm);
+ gguf_add_tensor(fctx, layer.w1);
+ gguf_add_tensor(fctx, layer.w2);
+ gguf_add_tensor(fctx, layer.w3);
}
+}
+
+void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
+ struct gguf_context * fctx = gguf_init_empty();
+
+ save_llama_model_gguf(fctx, fn_vocab_model, model);
- write_opt_context(&file, opt);
+ // write file
+ const bool only_meta = false;
+ gguf_write_to_file(fctx, filename, only_meta);
+ gguf_free(fctx);
}
-bool load_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename, bool init) {
- struct llama_file file(filename, "rb");
+void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) {
+ load_llama_model_gguf(fctx, f_ggml_ctx, model);
- uint32_t magic;
- uint32_t version;
+ uint32_t file_version;
+ GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
+ GGML_ASSERT(file_version == 0);
- uint32_t train_its = 0;
- uint32_t train_samples = 0;
- uint32_t train_tokens = 0;
+ GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
+ GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
+ GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
- if (file.fp) {
- printf("%s: Loading model from '%s'.\n", __func__, filename);
- magic = file.read_u32();
- GGML_ASSERT(magic == 'ggcp');
- version = file.read_u32();
- GGML_ASSERT(version == 0);
- train_its = file.read_u32();
- train_samples = file.read_u32();
- train_tokens = file.read_u32();
- model->hparams.n_vocab = file.read_u32();
- model->hparams.n_embd = file.read_u32();
- model->hparams.n_mult = file.read_u32();
- model->hparams.n_head = file.read_u32();
- model->hparams.n_layer = file.read_u32();
- model->hparams.n_rot = file.read_u32();
- print_params(&model->hparams);
- }
+ load_opt_context_gguf(fctx, f_ggml_ctx, opt);
+}
- if (init) {
- init_model(model);
- }
+void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
+ save_llama_model_gguf(fctx, fn_vocab_model, model);
- if (file.fp) {
- model->train_its = train_its;
- model->train_samples = train_samples;
- model->train_tokens = train_tokens;
- }
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);
- printf("%s: Training iterations: %u.\n", __func__, model->train_its);
- printf("%s: Training samples: %u.\n", __func__, model->train_samples);
- printf("%s: Training tokens: %u.\n", __func__, model->train_tokens);
-
- if (file.fp) {
- read_tensor(&file, model->tok_embeddings);
- read_tensor(&file, model->norm);
- read_tensor(&file, model->output);
-
- for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
- auto & layer = model->layers[i];
-
- read_tensor(&file, layer.attention_norm);
- read_tensor(&file, layer.wq);
- read_tensor(&file, layer.wk);
- read_tensor(&file, layer.wv);
- read_tensor(&file, layer.wo);
- read_tensor(&file, layer.ffn_norm);
- read_tensor(&file, layer.w1);
- read_tensor(&file, layer.w2);
- read_tensor(&file, layer.w3);
- }
+ save_opt_context_gguf(fctx, opt);
+}
- read_opt_context(&file, model->ctx, opt);
+bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) {
+ struct ggml_context * f_ggml_ctx;
+ struct gguf_init_params params;
+ params.no_alloc = false;
+ params.ctx = &f_ggml_ctx;
+ struct gguf_context * fctx = gguf_init_from_file(filename, params);
+ if (fctx == NULL) {
+ return false;
}
- return (file.fp != NULL);
+ load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt);
+
+ return true;
}
-void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * model, const char * filename) {
- struct llama_file file(filename, "wb");
- if (file.fp == NULL) {
- return;
- }
+void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
+ struct gguf_context * fctx = gguf_init_empty();
+
+ save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);
-#pragma message("TODO: implement file saving using gguf")
- (void) vocab;
- (void) model;
-// // write_magic
-// file.write_u32(LLAMA_FILE_MAGIC); // magic
-// file.write_u32(LLAMA_FILE_VERSION); // version
-// // write_hparams
-// file.write_u32(model->hparams.n_vocab);
-// file.write_u32(model->hparams.n_embd);
-// file.write_u32(model->hparams.n_mult);
-// file.write_u32(model->hparams.n_head);
-// file.write_u32(model->hparams.n_layer);
-// file.write_u32(model->hparams.n_rot);
-// file.write_u32(LLAMA_FTYPE_ALL_F32);
-// // write_vocab
-// uint32_t n_vocab = model->hparams.n_vocab;
-// for (uint32_t i = 0; i < n_vocab; i++) {
-// const auto & token_data = vocab->id_to_token.at(i);
-// file.write_u32((uint32_t) token_data.tok.size());
-// file.write_raw(token_data.tok.data(), token_data.tok.size());
-// file.write_raw(&token_data.score, sizeof(token_data.score));
-// }
-// // write tensors
-// write_tensor(&file, model->tok_embeddings);
-// write_tensor(&file, model->norm);
-// write_tensor(&file, model->output);
-// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
-// auto & layer = model->layers[i];
-//
-// write_tensor(&file, layer.attention_norm);
-// write_tensor(&file, layer.wq);
-// write_tensor(&file, layer.wk);
-// write_tensor(&file, layer.wv);
-// write_tensor(&file, layer.wo);
-// write_tensor(&file, layer.ffn_norm);
-// write_tensor(&file, layer.w1);
-// write_tensor(&file, layer.w2);
-// write_tensor(&file, layer.w3);
-// }
+ // write file
+ const bool only_meta = false;
+ gguf_write_to_file(fctx, filename, only_meta);
+ gguf_free(fctx);
}
-float cosine_decay(const int decay_steps, const float alpha, int step) {
+float cosine_decay(const int decay_steps, const float minimum, int step) {
if (step > decay_steps) {
step = decay_steps;
}
const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
- const float decay = (1 - alpha)*cosine_decay + alpha;
+ const float decay = (1 - minimum)*cosine_decay + minimum;
return decay;
}
-float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) {
- while (step > decay_steps) {
- step -= decay_steps;
- decay_steps = (int) restart_step_mult * decay_steps;
+float cosine_decay_restart(int decay_steps, const float minimum, int step, float restart_step_mult, bool enable_restart) {
+ if (enable_restart) {
+ while (step > decay_steps) {
+ step -= decay_steps;
+ decay_steps = (int) restart_step_mult * decay_steps;
+ }
}
- return cosine_decay(decay_steps, alpha, step);
+ return cosine_decay(decay_steps, minimum, step);
}
struct train_params {
int n_ctx;
int n_embd;
- int n_mult;
int n_head;
int n_layer;
- int n_rotmax;
+ int n_ff;
int n_threads;
int n_batch;
int n_examples;
- int n_predict;
+
+ float f_norm_rms_eps;
+ float rope_freq_base;
+ float rope_freq_scale;
int print_info_interval;
- int print_details_interval;
bool samples_start_after_nl;
bool use_adam;
bool use_flash;
- bool use_scratch;
+ bool use_checkpointing;
+ bool use_alloc;
// only adam
int warmup;
int cos_decay_steps;
float cos_decay_restart;
- float cos_decay_alpha;
+ float cos_decay_min;
+ bool enable_restart;
+
+ int opt_past;
+ float opt_delta;
+ int opt_max_no_improvement;
int lbfgs_n_iter;
int adam_n_iter;
float adam_alpha;
+ float adam_min_alpha;
float adam_decay;
+ int adam_decay_min_ndim;
+ float adam_beta1;
+ float adam_beta2;
+ float adam_gclip;
+ float adam_eps_f;
int mem_model_gb;
int mem_compute_gb;
int mem_compute0_gb;
- int mem_compute1_gb;
};
struct train_params get_default_train_params() {
params.n_ctx = 128;
params.n_embd = 256;
- params.n_mult = 256;
params.n_head = 8;
params.n_layer = 16;
- params.n_rotmax = 64;
+ params.n_ff = 768;
params.n_threads = 6;
params.n_batch = 8;
- params.n_examples = 8;
- params.n_predict = 1024;
+ params.n_examples = 1;
+
+ params.f_norm_rms_eps = 1e-5;
+ params.rope_freq_base = 10000.0f;
+ params.rope_freq_scale = 1.0f;
params.print_info_interval = 1;
- params.print_details_interval = 2;
params.samples_start_after_nl = false;
params.use_adam = true;
params.use_flash = true;
- params.use_scratch = true;
+ params.use_checkpointing = true;
+ params.use_alloc = true;
+
+ params.opt_past = 0;
+ params.opt_delta = 1e-5f;
+ params.opt_max_no_improvement = 0;
// only adam
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
- params.cos_decay_alpha = 0.0f;
-
- params.lbfgs_n_iter = 16;
- params.adam_n_iter = 16;
- params.adam_alpha = 1e-3f;
- params.adam_decay = 1e-3f;
-
- params.mem_model_gb = 2;
+ params.cos_decay_min = 0.1f;
+ params.enable_restart = false;
+
+ params.lbfgs_n_iter = 256;
+ params.adam_n_iter = 256;
+ params.adam_alpha = 1e-3f;
+ params.adam_min_alpha = 0;
+ params.adam_decay = 1e-1f;
+ params.adam_decay_min_ndim = 2;
+ params.adam_beta1 = 0.9f;
+ params.adam_beta2 = 0.999f;
+ params.adam_gclip = 1.0f;
+ params.adam_eps_f = 0.0f;
+
+ params.mem_model_gb = 2;
params.mem_compute_gb = 24;
params.mem_compute0_gb = 8;
- params.mem_compute1_gb = 2;
-
return params;
}
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
- fprintf(stderr, " --mult N Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult);
+ fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer);
- fprintf(stderr, " --rotmax N Maximal number Rope dimensions for new models (default %d)\n", params->n_rotmax);
+ fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
+ fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
+ fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
- fprintf(stderr, " --predict N Number of tokens to generate after training (default %d)\n", params->n_predict);
fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
- fprintf(stderr, " --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval);
fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
- fprintf(stderr, " --no-flash Don't use flash attention.\n");
+ fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
- fprintf(stderr, " --no-scratch Don't use scratch buffers\n");
- fprintf(stderr, " --use-scratch Use scratch buffers (default)\n");
- fprintf(stderr, " --warmup N Number of warmup steps (default %d)\n", params->warmup);
- fprintf(stderr, " --cos-decay-steps N Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
- fprintf(stderr, " --cos-decay-restart N Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
- fprintf(stderr, " --cos-decay-alpha N Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
- fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
+ fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
+ fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
+ fprintf(stderr, " --no-alloc Don't use allocator\n");
+ fprintf(stderr, " --use-alloc Use allocator (default)\n");
+ fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
+ fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+ fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+ fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
+ fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+ fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+ fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+ fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+ fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+ fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
+ fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+ fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
+ fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+ fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+ fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+ fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
- fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
- fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
+ fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
fprintf(stderr, "\n");
}
break;
}
params->n_embd = std::stoi(argv[i]);
- } else if (arg == "--mult") {
+ } else if (arg == "--ff") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_mult = std::stoi(argv[i]);
+ params->n_ff = std::stoi(argv[i]);
} else if (arg == "--head") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->n_layer = std::stoi(argv[i]);
- } else if (arg == "--rotmax") {
+ } else if (arg == "--norm-rms-eps") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_rotmax = std::stoi(argv[i]);
- } else if (arg == "-t" || arg == "--threads") {
+ params->f_norm_rms_eps = std::stof(argv[i]);
+ } else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_threads = std::stoi(argv[i]);
- } else if (arg == "-b" || arg == "--batch") {
+ params->rope_freq_base = std::stof(argv[i]);
+ } else if (arg == "--rope-freq-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_batch = std::stoi(argv[i]);
- } else if (arg == "-n" || arg == "--examples") {
+ params->rope_freq_scale = std::stof(argv[i]);
+ } else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_examples = std::stoi(argv[i]);
- } else if (arg == "--predict") {
+ params->n_threads = std::stoi(argv[i]);
+ } else if (arg == "-b" || arg == "--batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->n_predict = std::stoi(argv[i]);
- } else if (arg == "--print-info-interval") {
+ params->n_batch = std::stoi(argv[i]);
+ } else if (arg == "-n" || arg == "--examples") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->print_info_interval = std::stoi(argv[i]);
- } else if (arg == "--print-details-interval") {
+ params->n_examples = std::stoi(argv[i]);
+ } else if (arg == "--print-info-interval") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->print_details_interval = std::stoi(argv[i]);
+ params->print_info_interval = std::stoi(argv[i]);
} else if (arg == "--samples-after-nl") {
params->samples_start_after_nl = true;
} else if (arg == "--use-lbfgs") {
params->use_flash = false;
} else if (arg == "--use-flash") {
params->use_flash = true;
- } else if (arg == "--no-scratch") {
- params->use_scratch = false;
- } else if (arg == "--use-scratch") {
- params->use_scratch = true;
+ } else if (arg == "--no-checkpointing") {
+ params->use_checkpointing = false;
+ } else if (arg == "--use-checkpointing") {
+ params->use_checkpointing = true;
+ } else if (arg == "--no-alloc") {
+ params->use_alloc = false;
+ } else if (arg == "--use-alloc") {
+ params->use_alloc = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->cos_decay_restart = std::stof(argv[i]);
- } else if (arg == "--cos-decay-alpha") {
+ } else if (arg == "--cos-decay-min") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->cos_decay_alpha = std::stof(argv[i]);
- } else if (arg == "--lbfgs-iter") {
+ params->cos_decay_min = std::stof(argv[i]);
+ } else if (arg == "--enable-restart") {
+ params->enable_restart = true;
+ } else if (arg == "--disable-restart") {
+ params->enable_restart = false;
+ } else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->lbfgs_n_iter = std::stoi(argv[i]);
+ params->opt_past = std::stoi(argv[i]);
+ } else if (arg == "--opt-delta") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->opt_delta = std::stof(argv[i]);
+ } else if (arg == "--opt-max-no-improvement") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->opt_max_no_improvement = std::stoi(argv[i]);
+ } else if (arg == "--adam-epsf") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_alpha = std::stof(argv[i]);
+ } else if (arg == "--adam-min-alpha") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_decay = std::stof(argv[i]);
+ } else if (arg == "--adam-decay-min-ndim") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_decay_min_ndim = std::stoi(argv[i]);
+ } else if (arg == "--adam-beta1") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_beta1 = std::stof(argv[i]);
+ } else if (arg == "--adam-beta2") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_beta2 = std::stof(argv[i]);
+ } else if (arg == "--adam-gclip") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->adam_gclip = std::stof(argv[i]);
+ } else if (arg == "--lbfgs-iter") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->lbfgs_n_iter = std::stoi(argv[i]);
} else if (arg == "--mem-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->mem_compute0_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute1") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_compute1_gb = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
return true;
}
+struct opt_callback_data {
+ struct train_params * params;
+ struct ggml_opt_context * opt;
+ struct llama_context * lctx;
+ llama_token * tokens_data;
+ size_t tokens_size;
+ int * samples_data;
+ size_t samples_size;
+ int shuffle_countdown;
+ struct ggml_tensor * tokens_input;
+ struct ggml_tensor * target_logits;
+ struct ggml_tensor * target_probs;
+};
+
+void opt_callback(void * vdata, float * sched) {
+ struct opt_callback_data * data = (struct opt_callback_data *) vdata;
+ struct train_params * params = data->params;
+ struct ggml_opt_context * opt = data->opt;
+ int n_batch = params->n_batch;
+
+ *sched = (opt->iter < params->warmup)
+ ? (float) opt->iter / (float) params->warmup
+ : cosine_decay_restart(
+ params->cos_decay_steps,
+ params->cos_decay_min,
+ opt->iter - params->warmup,
+ params->cos_decay_restart,
+ params->enable_restart);
+ float min_sched = params->adam_min_alpha / params->adam_alpha;
+ *sched = min_sched + *sched * (1.0f - min_sched);
+
+ int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+ printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
+
+ if (data->shuffle_countdown < n_batch) {
+ printf("%s: reshuffle samples\n", __func__);
+ shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
+ for (int i = 0; i < (int) data->samples_size; ++i) {
+ GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size);
+ }
+ data->shuffle_countdown = data->samples_size;
+ }
+
+ get_example_targets_batch(
+ data->lctx,
+ data->samples_data,
+ data->samples_size,
+ data->tokens_data,
+ data->tokens_size,
+ opt->iter,
+ data->tokens_input,
+ data->target_logits,
+ data->target_probs);
+
+ data->shuffle_countdown -= n_batch;
+}
+
int main(int argc, char ** argv) {
struct train_params params = get_default_train_params();
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
- struct llama_vocab vocab;
- {
- const int n_vocab = llama_n_vocab(lctx);
- vocab.id_to_token.resize(n_vocab);
- for (int i=0; i<n_vocab; ++i) {
- vocab.id_to_token[i].text = llama_token_get_text(lctx, i);
- vocab.id_to_token[i].score = llama_token_get_score(lctx, i);
- vocab.id_to_token[i].type = llama_token_get_type(lctx, i);
- vocab.token_to_id.emplace(vocab.id_to_token[i].text, i);
- }
- }
-
printf("%s: tokenize training data\n", __func__);
std::vector<llama_token> train_tokens;
if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
model.hparams.n_vocab = llama_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx;
model.hparams.n_embd = params.n_embd;
- model.hparams.n_mult = params.n_mult;
model.hparams.n_head = params.n_head;
model.hparams.n_layer = params.n_layer;
- model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+ model.hparams.n_ff = params.n_ff;
+ // llama.cpp requires n_rot to be exactly n_embd / n_head
+ model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head;
+ model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+ model.hparams.rope_freq_base = params.rope_freq_base;
+ model.hparams.rope_freq_scale = params.rope_freq_scale;
print_params(&model.hparams);
}
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
- struct my_llama_kv_cache kv_self;
-
-
struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
lcparams.mem_buffer = NULL;
lcparams.no_alloc = false;
model.ctx = ggml_init(lcparams);
- kv_self.ctx = model.ctx;
-
- my_llama_sampler sampler;
-
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
- opt_params_adam.print_forward_graph = false;
+ opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
- opt_params_adam.n_threads = params.n_threads;
- opt_params_adam.adam.n_iter = params.adam_n_iter;
- opt_params_adam.adam.sched = 1.0f;
- opt_params_adam.adam.alpha = params.adam_alpha;
- opt_params_adam.adam.decay = params.adam_decay;
-
- opt_params_lbfgs.print_forward_graph = false;
+ opt_params_adam.n_threads = params.n_threads;
+ opt_params_adam.past = params.opt_past;
+ opt_params_adam.delta = params.opt_delta;
+ opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+ opt_params_adam.adam.n_iter = params.adam_n_iter;
+ opt_params_adam.adam.sched = 1.0f;
+ opt_params_adam.adam.alpha = params.adam_alpha;
+ opt_params_adam.adam.decay = params.adam_decay;
+ opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
+ opt_params_adam.adam.beta1 = params.adam_beta1;
+ opt_params_adam.adam.beta2 = params.adam_beta2;
+ opt_params_adam.adam.gclip = params.adam_gclip;
+ opt_params_adam.adam.eps_f = params.adam_eps_f;
+
+ opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
- opt_params_lbfgs.n_threads = params.n_threads;
- opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
+ opt_params_lbfgs.n_threads = params.n_threads;
+ opt_params_adam.past = params.opt_past;
+ opt_params_adam.delta = params.opt_delta;
+ opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+ opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
opt->ctx = model.ctx;
opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
printf("%s: init model\n", __func__);
- bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true);
+ bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
+ if (!existed) {
+ init_model(&model);
+ }
set_param_model(&model);
opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
}
- init_kv_cache(&kv_self, &model, 1);
- // init_kv_cache(&kv_self, &model, n_batch);
- init_sampler(&sampler, lctx);
-
- printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
+ printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
// ggml_print_tensor_objects(model.ctx);
// TODO: use std::vector<uint8_t> intead of "new"
uint8_t * compute_addr = new uint8_t[compute_size];
size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
- size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
- uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
+
+ ggml_allocr * alloc = NULL;
+ if (params.use_alloc) {
+ static const size_t tensor_alignment = 32;
+ alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
+ }
GGML_ASSERT(n_tokens < (int) train_tokens.size());
std::vector<int> train_samples;
GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
}
- std::vector<uint8_t> work_buffer;
-
printf("%s: begin training\n", __func__);
+ struct opt_callback_data opt_cb_data;
+ opt_cb_data.params = ¶ms;
+ opt_cb_data.opt = opt;
+ opt_cb_data.lctx = lctx;
+ opt_cb_data.tokens_data = train_tokens.data();
+ opt_cb_data.tokens_size = train_tokens.size();
+ opt_cb_data.samples_data = train_samples.data();
+ opt_cb_data.samples_size = train_samples.size();
+ opt_cb_data.shuffle_countdown = train_samples.size();
+ opt_cb_data.tokens_input = NULL;
+ opt_cb_data.target_logits = NULL;
+ opt_cb_data.target_probs = NULL;
+
+ int64_t t0 = ggml_time_ms();
+
for (int ex = 0; ex < params.n_examples; ++ex) {
if (ex*n_batch >= (int) train_samples.size()) {
shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
}
struct ggml_init_params cparams = {
- /*.mem_size =*/ compute_size,
- /*.mem_buffer =*/ compute_addr,
- /*.no_alloc =*/ false,
+ compute_size, // mem_size
+ compute_addr, // mem_buffer
+ false, // no_alloc
};
struct ggml_context * ctx0 = ggml_init(cparams);
- struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+ ggml_set_no_alloc(ctx0, false);
+
+ // don't use alloc for input tensors, so we can safely fill them with data
+ //struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
//struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- int n_past = 0;
-
- struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
- struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
+ ggml_set_no_alloc(ctx0, (alloc != NULL));
- memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
- memset(gbbuf->data, 0, ggml_nbytes(gbbuf));
+ if (alloc) {
+ ggml_allocr_reset(alloc);
+ }
- struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
- struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+ opt_cb_data.tokens_input = tokens_input;
+ opt_cb_data.target_logits = target_logits;
+ opt_cb_data.target_probs = target_probs;
+ int n_past = 0;
- get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gb_tmp = params.use_checkpointing
+ ? ggml_new_graph(ctx0)
+ : NULL;
GGML_ASSERT(n_past == 0);
struct ggml_tensor * loss = NULL;
struct ggml_tensor * logits = NULL;
- if (params.use_scratch) {
- loss = forward_batch_wo_cache_flash_attn_train(
- &model, ctx0,
- gf, gb,
- &logits, tokens_input, target_probs,
- compute_buf_0, compute_buf_1,
- size_buf_0, size_buf_1,
- n_tokens, n_batch);
- } else if (params.use_flash) {
- logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
- loss = cross_entropy_loss(ctx0, logits, target_probs);
- ggml_build_forward_expand(gf, loss);
- *gb = ggml_build_backward(ctx0, gf, true);
- } else {
- logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
- loss = cross_entropy_loss(ctx0, logits, target_probs);
- ggml_build_forward_expand(gf, loss);
- *gb = ggml_build_backward(ctx0, gf, true);
- }
-
- ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
+ loss = llama_build_train_graphs(
+ &model, alloc, ctx0,
+ gf, gb, gb_tmp,
+ &logits, tokens_input, target_probs,
+ n_tokens, n_batch,
+ params.use_flash,
+ params.use_checkpointing
+ );
size_t used_mem_before_opt = ggml_used_mem(ctx0);
- float error_before_opt = ggml_get_f32_1d(loss, 0);
-
opt->params.adam.sched = (opt->iter < params.warmup)
? (float) opt->iter / (float) params.warmup
: cosine_decay_restart(
params.cos_decay_steps,
- params.cos_decay_alpha,
+ params.cos_decay_min,
opt->iter - params.warmup,
- params.cos_decay_restart);
+ params.cos_decay_restart,
+ params.enable_restart);
+
+ float min_sched = params.adam_min_alpha / params.adam_alpha;
+ opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
- ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
+ ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
size_t used_mem_after_opt = ggml_used_mem(ctx0);
+ int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
model.train_its = opt->iter;
- model.train_samples += n_batch;
- model.train_tokens += n_batch * n_tokens;
-
- ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
-
- float error_after_opt = ggml_get_f32_1d(loss, 0);
+ model.train_samples += n_batch * n_iter;
+ model.train_tokens += n_batch * n_tokens * n_iter;
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
printf("Example %d, opt iter %d\n", ex, opt->iter);
- printf("error_before_opt: %.6f\n", error_before_opt);
- printf("error_after_opt: %.6f\n", error_after_opt);
+ printf("error_before_opt: %.6f\n", opt->loss_before);
+ printf("error_after_opt: %.6f\n", opt->loss_after);
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
}
- if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) {
- // set_logits_masked(logits, token_notavail, -1e9);
- for (int i=0; i<n_batch; ++i) {
- init_sampler(&sampler, lctx);
- for (int k=0; k<n_tokens; ++k) {
- int32_t token = sample(&sampler,
- (float *) ((char *) logits->data + i*logits->nb[2] + k*logits->nb[1]),
- (llama_token *) ((char *) tokens_input->data + i*tokens_input->nb[1]),
- k);
- * ((int32_t *) ((char *) after_opt_best_samples->data + i*after_opt_best_samples->nb[1] + k*after_opt_best_samples->nb[0])) = token;
- }
- }
-
- // printf("probabilities after optimization:\n");
- // print_matrix(after_opt_probs);
- printf("Example:\n---\n");
- print_tokens_batch(lctx, tokens_input);
- printf("\n---\n");
-
- // printf("best samples after optimization:\n---\n");
- printf("samples after optimization:\n---\n");
- print_tokens_batch(lctx, after_opt_best_samples);
- printf("\n---\n");
- }
-
ggml_free(ctx0);
}
+ int64_t t1 = ggml_time_ms();
+ int64_t d = t1-t0;
+ double dd = (double) d * 1e-3;
+ printf("%s: total training time=%f seconds\n", __func__, dd);
+
if (params.n_examples > 0) {
- save_checkpoint(&model, opt, params.fn_checkpoint_out);
+ save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
}
if (strlen(params.fn_model_out) > 0) {
- save_as_llama_model(&vocab, &model, params.fn_model_out);
+ save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model);
}
- {
- int n_gen = params.n_predict;
- int sample_ctx = n_tokens - n_tokens/8;
-
- sampler.params.temp = 0.2f;
- sampler.params.repeat_penalty = 1.1f;
- sampler.params.mirostat = 2;
- init_sampler(&sampler, lctx);
-
- printf("Generating %d tokens.\n", n_gen);
-
- struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
- struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
- struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
-
- get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
- for (int i=sample_ctx; i<n_tokens; ++i) {
- ggml_set_i32_1d(tokens_input, i, n_vocab/2);
- }
-
- for (int i=0; i<sample_ctx-1; ++i) {
- print_token(lctx, ggml_get_i32_1d(tokens_input, i));
- }
-
- printf("---\n");
- for (int i=0; i<n_gen; ++i) {
- struct ggml_init_params cparams = {
- /*.mem_size =*/ compute_size,
- /*.mem_buffer =*/ compute_addr,
- /*.no_alloc =*/ false,
- };
- struct ggml_context * ctx0 = ggml_init(cparams);
-
- ggml_cgraph gf = {};
-
- int n_past = 0;
- struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
-
- ggml_build_forward_expand(&gf, logits);
- ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
-
- //struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
- //struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
-
- // set_logits_masked(logits, token_notavail, -1e9);
- int token = sample(&sampler,
- (float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
- (llama_token *) tokens_input->data,
- sample_ctx-1);
- //int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
-
- // print_row(probs, sample_at);
- print_token(lctx, token);
-
- lshift_examples(tokens_input, target_logits, target_probs, 1);
- ggml_set_i32_1d(tokens_input, 0, 0);
- ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
-
- ggml_free(ctx0);
- }
+ if (alloc) {
+ ggml_allocr_free(alloc);
}
delete[] compute_addr;
delete[] compute_buf_0;
- delete[] compute_buf_1;
-
+ ggml_free(model.ctx);
llama_free(lctx);
llama_free_model(lmodel);
- ggml_free(model.ctx);
-
return 0;
}