/server
/simple
/batched
+/export-lora
+/finetune
/speculative
/parallel
/train-text-from-scratch
# Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel finetune export-lora tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@
+train.o: common/train.cpp common/train.h
+ $(CXX) $(CXXFLAGS) -c $< -o $@
+
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o $(OBJS)
+train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS)
$(CXX) $(TTFS_CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS)
llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
+baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS)
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
+export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
console.cpp
grammar-parser.h
grammar-parser.cpp
+ train.h
+ train.cpp
)
if (BUILD_SHARED_LIBS)
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
-static void process_escapes(std::string& input) {
+void process_escapes(std::string& input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
invalid_param = true;
break;
}
- params.lora_adapter = argv[i];
+ params.lora_adapter.push_back({argv[i], 1.0f});
+ params.use_mmap = false;
+ } else if (arg == "--lora-scaled") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ const char * lora_adapter = argv[i];
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
params.use_mmap = false;
} else if (arg == "--lora-base") {
if (++i >= argc) {
printf(" --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+ printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}
- if (!params.lora_adapter.empty()) {
+ for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
+ const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+ float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
- params.lora_adapter.c_str(),
- params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+ lora_adapter.c_str(),
+ lora_scale,
+ ((i > 0) || params.lora_base.empty())
+ ? NULL
+ : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
fprintf(stream, " %d: %f", lb.first, lb.second);
}
- fprintf(stream, "lora: %s\n", params.lora_adapter.c_str());
+ fprintf(stream, "lora:\n");
+ for (std::tuple<std::string, float> la : params.lora_adapter) {
+ if (std::get<1>(la) != 1.0f) {
+ continue;
+ }
+ fprintf(stream, " - %s\n", std::get<0>(la).c_str());
+ }
+ fprintf(stream, "lora_scaled:\n");
+ for (std::tuple<std::string, float> la : params.lora_adapter) {
+ if (std::get<1>(la) == 1.0f) {
+ continue;
+ }
+ fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
+ }
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
- std::string lora_adapter = ""; // lora adapter path
- std::string lora_base = ""; // base model path for the lora adapter
+ std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
+ std::string lora_base = ""; // base model path for the lora adapter
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
std::string gpt_random_prompt(std::mt19937 & rng);
+void process_escapes(std::string& input);
+
//
// Model utils
//
--- /dev/null
+#include "train.h"
+#include "common.h"
+
+#include <random>
+#include <sstream>
+#include <functional>
+
+struct random_normal_distribution {
+ std::mt19937 gen;
+ std::normal_distribution<float> rd;
+ float min;
+ float max;
+};
+
+struct random_uniform_distribution {
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> rd;
+};
+
+struct train_state * init_train_state() {
+ struct train_state * state = new struct train_state;
+ state->train_its = 0;
+ state->train_samples = 0;
+ state->train_tokens = 0;
+ state->train_epochs = 0;
+ state->shuffle_samples_hash = 0;
+ state->shuffle_sample_count = 0;
+ state->shuffle_next_sample = 0;
+ state->shuffle_rng_state_current = "";
+ state->shuffle_rng_state_next = "";
+
+ state->opt = new struct ggml_opt_context;
+ state->opt->ctx = NULL;
+ state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+ state->opt->loss_after = 0.0f;
+
+ return state;
+}
+
+void free_train_state(struct train_state * state) {
+ delete state->opt;
+ delete state;
+}
+
+struct random_normal_distribution * init_random_normal_distribution(
+ int seed, float mean, float std, float min, float max
+) {
+ struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
+ rnd->gen = std::mt19937(seed);
+ rnd->rd = std::normal_distribution<float>{mean, std};
+ rnd->min = min;
+ rnd->max = max;
+ return rnd;
+}
+
+struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
+ struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
+ rnd->gen = std::mt19937(seed);
+ rnd->rd = std::uniform_real_distribution<float>{min, max};
+ return rnd;
+}
+
+void free_random_normal_distribution (struct random_normal_distribution * rnd) {
+ free(rnd);
+}
+
+void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
+ free(rnd);
+}
+
+struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
+ float scale = 1.0f; // xavier
+ switch (tensor->n_dims) {
+ case 1:
+ scale /= sqrtf((float) tensor->ne[0]);
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
+ *dst = scale * frand_normal(rnd);
+ }
+ break;
+ case 2:
+ scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
+ *dst = scale * frand_normal(rnd);
+ }
+ }
+ break;
+ case 3:
+ scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
+ *dst = scale * frand_normal(rnd);
+ }
+ }
+ }
+ break;
+ case 4:
+ scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
+ for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
+ *dst = scale * frand_normal(rnd);
+ }
+ }
+ }
+ }
+ break;
+ default:
+ die("Unsupported tensor->n_dims");
+ };
+ return tensor;
+}
+
+struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
+ switch (tensor->n_dims) {
+ case 1:
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
+ *dst = frand_uniform(rnd);
+ }
+ break;
+ case 2:
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
+ *dst = frand_uniform(rnd);
+ }
+ }
+ break;
+ case 3:
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
+ *dst = frand_uniform(rnd);
+ }
+ }
+ }
+ break;
+ case 4:
+ for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
+ *dst = frand_uniform(rnd);
+ }
+ }
+ }
+ }
+ break;
+ default:
+ die("Unsupported tensor->n_dims");
+ };
+ return tensor;
+}
+
+float frand() {
+ return (float)rand()/((float)(RAND_MAX) + 1.0f);
+}
+
+float frand_normal(struct random_normal_distribution * rnd) {
+ return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
+}
+
+float frand_uniform(struct random_uniform_distribution * rnd) {
+ return rnd->rd(rnd->gen);
+}
+
+int clamp(const int v, const int min, const int max) {
+ return ((v < min) ? (min) : (v > max) ? (max) : v);
+}
+
+float fclamp(const float v, const float min, const float max) {
+ return ((v < min) ? (min) : (v > max) ? (max) : v);
+}
+
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
+ GGML_ASSERT(tensor->n_dims == 1);
+ GGML_ASSERT(tensor->ne[0] == ne0);
+}
+
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
+ GGML_ASSERT(tensor->n_dims == 2);
+ GGML_ASSERT(tensor->ne[0] == ne0);
+ GGML_ASSERT(tensor->ne[1] == ne1);
+}
+
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
+ GGML_ASSERT(tensor->n_dims == 3);
+ GGML_ASSERT(tensor->ne[0] == ne0);
+ GGML_ASSERT(tensor->ne[1] == ne1);
+ GGML_ASSERT(tensor->ne[2] == ne2);
+}
+
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+ GGML_ASSERT(tensor->n_dims == 4);
+ GGML_ASSERT(tensor->ne[0] == ne0);
+ GGML_ASSERT(tensor->ne[1] == ne1);
+ GGML_ASSERT(tensor->ne[2] == ne2);
+ GGML_ASSERT(tensor->ne[3] == ne3);
+}
+
+int64_t get_example_targets_batch(
+ struct llama_context * lctx,
+ struct ggml_tensor * tokens_input,
+ struct ggml_tensor * target_probs,
+ int64_t example_id,
+ const size_t * samples_offs,
+ const size_t * samples_begin,
+ const size_t * samples_size,
+ size_t samples_count,
+ const llama_token * train_data,
+ size_t n_train_data,
+ bool separate_with_eos,
+ bool separate_with_bos,
+ bool fill_with_next_samples,
+ bool sample_random_offsets
+) {
+ GGML_ASSERT(samples_count > 0);
+ GGML_ASSERT(tokens_input->n_dims == 2);
+ GGML_ASSERT(target_probs->n_dims == 3);
+ int64_t n_vocab = target_probs->ne[0];
+ int64_t n_tokens = tokens_input->ne[0];
+ int64_t n_batch = tokens_input->ne[1];
+ GGML_ASSERT(n_vocab == target_probs->ne[0]);
+ GGML_ASSERT(n_tokens == target_probs->ne[1]);
+ GGML_ASSERT(n_batch == target_probs->ne[2]);
+
+ int64_t used_samples = 0;
+
+ ggml_set_f32(target_probs, 0.0f);
+ llama_token bos = llama_token_bos(lctx);
+ llama_token eos = llama_token_eos(lctx);
+ // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
+ for (int k=0; k<n_batch; ++k) {
+ // printf("%s: batch %d\n", __func__, k);
+ size_t sample_idx = (example_id + used_samples) % samples_count;
+ size_t sample_offs = sample_random_offsets ? samples_offs[sample_idx] : 0;
+ size_t sample_begin = samples_begin[sample_idx];
+ size_t sample_size = samples_size[sample_idx];
+ ++used_samples;
+
+ // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
+ GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
+
+ ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
+ bool sample_separation_eos = !separate_with_eos;
+ bool sample_separation_bos = !separate_with_bos;
+ for (int64_t i=0; i<n_tokens; ++i) {
+ llama_token token = eos;
+ if (sample_offs >= sample_size && fill_with_next_samples) {
+ if (!sample_separation_eos) {
+ // insert eos token to separate samples
+ sample_separation_eos = true;
+ } else if (!sample_separation_bos) {
+ // insert bos token to separate samples
+ sample_separation_bos = true;
+ token = bos;
+ } else {
+ // sample separation is done, continue with next sample
+ sample_separation_eos = !separate_with_eos;
+ sample_separation_bos = !separate_with_bos;
+ sample_offs = 0;
+ sample_idx = (example_id + used_samples) % samples_count;
+ sample_begin = samples_begin[sample_idx];
+ sample_size = samples_size[sample_idx];
+ ++used_samples;
+ }
+ }
+ // note: no else-if here
+ if (sample_offs < sample_size) {
+ token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
+ ++sample_offs;
+ }
+ ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
+ if (i+1<n_tokens) {
+ ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
+ }
+ }
+ }
+
+ return used_samples;
+}
+
+void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
+ std::stringstream s_rng_state;
+ s_rng_state.imbue(std::locale::classic());
+ s_rng_state.exceptions(std::stringstream::failbit);
+ s_rng_state.str(rng_state);
+ s_rng_state >> rng;
+}
+
+std::string mt19937_get_state(const std::mt19937& rng) {
+ std::stringstream s_rng_state;
+ s_rng_state.imbue(std::locale::classic());
+ s_rng_state << rng;
+ return s_rng_state.str();
+}
+
+std::string mt19937_seed_to_state(unsigned seed) {
+ std::mt19937 rng(seed);
+ return mt19937_get_state(rng);
+}
+
+std::string shuffle_samples(
+ const std::string & rng_state,
+ size_t * shuffled_offs,
+ size_t * shuffled_begins,
+ size_t * shuffled_sizes,
+ const size_t * begins,
+ const size_t * sizes,
+ size_t count) {
+ if (count == 0) return rng_state;
+
+ std::mt19937 rng;
+ mt19937_set_state(rng, rng_state);
+
+ // sort indices by random value for each index
+ std::vector<size_t> idcs;
+ {
+ std::vector<unsigned> rnd;
+ idcs.resize(count);
+ rnd.resize(count);
+ for (unsigned i=0; i<count; ++i) {
+ idcs[i] = i;
+ rnd[i] = rng();
+ }
+
+ std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
+ // stable sort for reproducibility
+ return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
+ });
+ }
+
+ // create random offsets
+ for (unsigned i=0; i<count; ++i) {
+ shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
+ }
+
+ // reorder begins and sizes by sorted indices
+ for (unsigned i=0; i<count; ++i) {
+ shuffled_begins[i] = begins[idcs[i]];
+ }
+
+ for (unsigned i=0; i<count; ++i) {
+ shuffled_sizes[i] = sizes[idcs[i]];
+ }
+
+ return mt19937_get_state(rng);
+}
+
+size_t hash_combine(size_t h1, size_t h2) {
+ return h1 ^ (h2 << 1);
+}
+
+size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
+ std::hash<std::string> h_string;
+ std::hash<unsigned long long> h_ull;
+ size_t h = h_string(std::string(fn));
+ h = hash_combine(h, h_ull((unsigned long long) sample_count));
+ for (size_t i=0; i< sample_count; ++i) {
+ h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
+ h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
+ }
+ return h;
+}
+
+std::string replace_str(const char * s, const char * needle, const char * replacement) {
+ std::string str = s;
+ size_t pos = str.find(needle);
+ if (pos != std::string::npos) {
+ str.replace(pos, strlen(needle), replacement);
+ }
+ return str;
+}
+
+void print_duration(double fmillis) {
+ if (fmillis < 1000.0f) {
+ printf("%.1fms", (float) fmillis);
+ return;
+ }
+ const int64_t one_sec = 1000;
+ const int64_t one_min = one_sec * 60;
+ const int64_t one_hour = one_min * 60;
+ const int64_t one_day = one_hour * 24;
+
+ int64_t millis = (int64_t) fmillis;
+ int64_t days = millis/one_day;
+ int64_t hours = (millis - days*one_day)/one_hour;
+ int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
+ int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
+
+ // to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
+ if (days > 0) {
+ printf("%lldd ", (long long int) days);
+ }
+ printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
+}
+
+float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
+ if (step > decay_steps) {
+ step = decay_steps;
+ }
+ const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
+ const float decay = (1 - minimum)*cosine_decay + minimum;
+ return decay;
+}
+
+float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
+ while (step > decay_steps) {
+ step -= decay_steps;
+ decay_steps = (int64_t) (restart_step_mult * decay_steps);
+ }
+ return cosine_decay(step, decay_steps, minimum);
+}
+
+float learning_schedule(
+ int64_t step,
+ int64_t warmup_steps,
+ int64_t cos_decay_steps,
+ float learning_rate,
+ float overall_minimum,
+ float cos_decay_minimum,
+ float cos_decay_restart_step_mult,
+ bool enable_restart) {
+
+ float result =
+ (step < warmup_steps)
+ ? (float) step / (float) warmup_steps
+ : enable_restart
+ ? cosine_decay_restart(
+ step - warmup_steps,
+ cos_decay_steps,
+ cos_decay_minimum,
+ cos_decay_restart_step_mult)
+ : cosine_decay(
+ step,
+ cos_decay_steps,
+ cos_decay_minimum);
+
+ float min = overall_minimum / learning_rate;
+ result = min + result * (1.0f - min);
+ return result;
+}
+
+static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
+ GGML_ASSERT(a != NULL);
+ GGML_ASSERT(b != NULL);
+ GGML_ASSERT(a->type == b->type);
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
+
+ return true;
+}
+
+void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
+ if (dst == NULL) {
+ return;
+ }
+ struct ggml_tensor * t = ggml_get_tensor(ctx, name);
+ GGML_ASSERT(are_same_layout(dst, t));
+ memcpy(dst->data, t->data, ggml_nbytes(t));
+
+ if (strlen(ggml_get_name(dst)) == 0) {
+ ggml_set_name(dst, name);
+ }
+}
+
+// gguf constants
+static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
+static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
+static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
+static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
+static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
+static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
+static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
+static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
+static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
+static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
+static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
+static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
+static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
+static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
+static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
+
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
+static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
+
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
+static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
+
+static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
+static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
+static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
+static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
+static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
+static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
+static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
+static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
+static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+ const std::string skey(key); \
+ const int kid = gguf_find_key(ctx, skey.c_str()); \
+ if (kid >= 0) { \
+ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+ if (ktype != (type)) { \
+ die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
+ } \
+ (dst) = func(ctx, kid); \
+ } else if (req) { \
+ die_fmt("key not found in model: %s", skey.c_str()); \
+ } \
+}
+
+void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
+ // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+
+ uint32_t file_version;
+ GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
+ GGML_ASSERT(file_version == 0);
+
+ GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
+ GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
+ GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
+
+ uint64_t nx;
+ GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
+ opt->nx = (size_t) nx;
+
+ // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
+
+ std::string opt_type;
+ GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
+ if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
+ opt->params.type = GGML_OPT_ADAM;
+
+ GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
+ GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
+ GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
+
+ ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+ copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+ copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+ copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+ } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
+ opt->params.type = GGML_OPT_LBFGS;
+
+ GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
+ GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
+ GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
+ GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
+ GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
+ GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
+ GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
+
+ ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
+
+ copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+ copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+ copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+ copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+ copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+ copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+ copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+ copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+ copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+ copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+ } else {
+ die("unknown optimizer type\n");
+ }
+}
+
+void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
+ gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
+ gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
+
+ switch (opt->params.type) {
+ case GGML_OPT_ADAM:
+ {
+ gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
+
+ ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
+ ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
+ if (opt->adam.pf) {
+ ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
+ }
+
+ gguf_add_tensor(fctx, opt->adam.m);
+ gguf_add_tensor(fctx, opt->adam.v);
+ if (opt->adam.pf) {
+ gguf_add_tensor(fctx, opt->adam.pf);
+ }
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
+ gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
+ gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
+ gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
+
+ ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
+ ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
+ ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
+ ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
+ ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
+ if (opt->lbfgs.pf) {
+ ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
+ }
+ ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
+ ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
+ ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
+ ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
+
+ gguf_add_tensor(fctx, opt->lbfgs.x);
+ gguf_add_tensor(fctx, opt->lbfgs.xp);
+ gguf_add_tensor(fctx, opt->lbfgs.g);
+ gguf_add_tensor(fctx, opt->lbfgs.gp);
+ gguf_add_tensor(fctx, opt->lbfgs.d);
+ if (opt->lbfgs.pf) {
+ gguf_add_tensor(fctx, opt->lbfgs.pf);
+ }
+ gguf_add_tensor(fctx, opt->lbfgs.lmal);
+ gguf_add_tensor(fctx, opt->lbfgs.lmys);
+ gguf_add_tensor(fctx, opt->lbfgs.lms);
+ gguf_add_tensor(fctx, opt->lbfgs.lmy);
+ } break;
+ }
+}
+
+bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
+ if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
+ return false;
+ }
+
+ uint32_t file_version;
+ GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
+ GGML_ASSERT(file_version <= 1);
+
+ if (file_version == 0) {
+
+ GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
+ GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
+ GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
+
+ } else if (file_version == 1) {
+
+ GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
+ GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
+ GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
+ GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
+
+ GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
+ GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
+ GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
+ GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
+ }
+
+ load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
+ return true;
+}
+
+void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs);
+
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
+ gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str());
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
+ gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample);
+
+ save_opt_context_gguf(fctx, train->opt);
+}
+
+
+struct llama_file {
+ // use FILE * so we don't have to re-open the file to mmap
+ FILE * fp;
+ size_t size;
+
+ llama_file(const char * fname, const char * mode) {
+ fp = std::fopen(fname, mode);
+ if (fp == NULL) {
+ size = 0;
+ } else {
+ seek(0, SEEK_END);
+ size = tell();
+ seek(0, SEEK_SET);
+ }
+ }
+
+ size_t tell() const {
+#ifdef _WIN32
+ __int64 ret = _ftelli64(fp);
+#else
+ long ret = std::ftell(fp);
+#endif
+ GGML_ASSERT(ret != -1); // this really shouldn't fail
+ return (size_t) ret;
+ }
+
+ void seek(size_t offset, int whence) {
+#ifdef _WIN32
+ int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+ int ret = std::fseek(fp, (long) offset, whence);
+#endif
+ GGML_ASSERT(ret == 0); // same
+ }
+
+ void read_raw(void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ std::size_t ret = std::fread(ptr, size, 1, fp);
+ if (ferror(fp)) {
+ die_fmt("read error: %s", strerror(errno));
+ }
+ if (ret != 1) {
+ die("unexpectedly reached end of file");
+ }
+ }
+
+ std::uint32_t read_u32() {
+ std::uint32_t ret;
+ read_raw(&ret, sizeof(ret));
+ return ret;
+ }
+
+ std::string read_string(std::uint32_t len) {
+ std::vector<char> chars(len);
+ read_raw(chars.data(), len);
+ return std::string(chars.data(), len);
+ }
+
+ void write_raw(const void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ size_t ret = std::fwrite(ptr, size, 1, fp);
+ if (ret != 1) {
+ die_fmt("write error: %s", strerror(errno));
+ }
+ }
+
+ void write_u32(std::uint32_t val) {
+ write_raw(&val, sizeof(val));
+ }
+
+ ~llama_file() {
+ if (fp) {
+ std::fclose(fp);
+ }
+ }
+};
+
+static size_t utf8_len(char src) {
+ const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+ uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+ return lookup[highbits];
+}
+
+// mark each byte with its utf8 unit number.
+// returns the number of utf8 characters.
+// e.g. when bytes == '\x61\xD0\xB0\x62',
+// then utf8_units will become [0,0,1,0]
+// utf8_nunits will become [1,2,2,1] and 3 is returned.
+// bytes where utf8_units is zero, are the begin of an utf8 character.
+static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
+ size_t offs = 0;
+ size_t count_utf8 = 0;
+ while(offs < count) {
+ int len = (int) utf8_len(bytes[offs]);
+ for (int i=0; i<len; ++i) {
+ utf8_units[offs+i] = i;
+ utf8_nunits[offs+i] = len;
+ }
+ offs += len;
+ ++count_utf8;
+ }
+ return count_utf8;
+}
+
+size_t tokenize_file(
+ struct llama_context * lctx,
+ const char * filename,
+ const std::string & sample_start,
+ bool include_sample_start,
+ bool overlapping_samples,
+ unsigned context_length,
+ std::vector<llama_token> & out_tokens,
+ std::vector<size_t> & out_samples_begin,
+ std::vector<size_t> & out_samples_size) {
+ struct llama_file f(filename, "rb");
+
+ if (f.size == 0) {
+ out_tokens.clear();
+ out_samples_begin.clear();
+ out_samples_size.clear();
+ printf("%s: warning: empty or not existing training data file '%s'\n",
+ __func__, filename);
+ return out_tokens.size();
+ }
+
+ // account for possible leading whitespace that will be added by tokenizer
+ // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
+ const int n_max_tokens_overhead = 1;
+
+ std::vector<char> buf;
+ buf.resize(f.size);
+
+ f.read_raw(buf.data(), f.size);
+
+ std::vector<int> utf8_units;
+ std::vector<int> utf8_nunits;
+ utf8_units.resize(buf.size());
+ utf8_nunits.resize(buf.size());
+ mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
+
+ if (sample_start.size() == 0) {
+ // tokenize all data at once
+ out_tokens.resize(buf.size() + n_max_tokens_overhead);
+
+ int n_tokens = llama_tokenize(
+ lctx,
+ buf.data(),
+ (int) buf.size(),
+ out_tokens.data(),
+ (int) out_tokens.size(),
+ false);
+ if (n_tokens < 0) {
+ out_tokens.resize(-n_tokens);
+ n_tokens = llama_tokenize(
+ lctx,
+ buf.data(),
+ (int) buf.size(),
+ out_tokens.data(),
+ (int) out_tokens.size(),
+ false);
+ }
+ if (n_tokens >= 0) {
+ out_tokens.resize(n_tokens);
+ }
+
+ // generate sample starts at all token positions
+ out_samples_begin.clear();
+ out_samples_begin.push_back(0);
+ out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
+ size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
+ for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
+ out_samples_begin.push_back(sample_begin);
+ out_samples_size.push_back(context_length);
+ }
+ } else {
+ // split data into samples and tokenize each sample
+ std::string data_str(buf.data(), buf.size());
+ out_samples_begin.clear();
+ out_samples_size.clear();
+ out_tokens.clear();
+
+ // find all positions of pattern sample_start
+ size_t sample_begin = data_str.find(sample_start, 0);
+ while (sample_begin != std::string::npos) {
+ out_samples_begin.push_back(sample_begin);
+ const size_t search_start = sample_begin + sample_start.size();
+ sample_begin = data_str.find(sample_start, search_start);
+ }
+ if (out_samples_begin.size() == 0) {
+ printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
+ __func__, sample_start.c_str());
+ out_samples_begin.push_back(0);
+ }
+
+ out_samples_size.resize(out_samples_begin.size(), 0);
+
+ std::vector<char> buf_sample;
+ std::vector<llama_token> tok_sample;
+
+ const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
+ size_t found_too_big_sample = 0;
+ size_t found_too_small_sample = 0;
+ size_t found_empty_sample = 0;
+ size_t found_min_sample_size = SIZE_MAX;
+ size_t found_max_sample_size = 0;
+
+ size_t max_token_text_size = 0;
+ int n_vocab = llama_n_vocab(lctx);
+ for (llama_token token=0; token < n_vocab; ++token) {
+ max_token_text_size = std::max(
+ max_token_text_size,
+ strlen(llama_token_get_text(lctx, token)));
+ }
+
+ // upper bound of context byte length.
+ // strings with this byte length should always tokenize to at least context_length tokens.
+ size_t context_byte_len = max_token_text_size*context_length;
+
+ for (unsigned i=0; i<out_samples_begin.size(); ++i) {
+ // determine sample begin and end from pattern positions
+ size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
+ size_t sample_end = overlapping_samples
+ ? std::min(
+ data_str.size(),
+ sample_begin + context_byte_len)
+ : (i+1 < out_samples_begin.size()
+ ? out_samples_begin[i+1]
+ : data_str.size());
+ if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
+ // sample end is in the middle of an utf8 character.
+ // advance sample_end to the begin of the next utf8 character.
+ sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
+ }
+ size_t sample_size = sample_end - sample_begin;
+ if (sample_size == 0) {
+ ++found_empty_sample;
+ }
+
+ if (sample_size > 0) {
+ // llama_tokenize expects zero terminated string,
+ // copy sample into buffer and zero terminate it.
+ buf_sample.resize(sample_size);
+ memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
+
+ // printf("sample: '%s'\n", buf_sample.data());
+
+ // tokenize the sample
+ tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
+ int n_tokens = llama_tokenize(lctx,
+ buf_sample.data(),
+ (int) buf_sample.size(),
+ tok_sample.data(),
+ (int) tok_sample.size(),
+ false);
+ if (n_tokens < 0) {
+ tok_sample.resize(-n_tokens);
+ n_tokens = llama_tokenize(lctx,
+ buf_sample.data(),
+ (int) buf_sample.size(),
+ tok_sample.data(),
+ (int) tok_sample.size(),
+ false);
+ GGML_ASSERT(n_tokens >= 0);
+ }
+ GGML_ASSERT(n_tokens <= (int) tok_sample.size());
+
+ if ((size_t) n_tokens > context_length) {
+ ++found_too_big_sample;
+ } else if ((size_t) n_tokens < context_length) {
+ ++found_too_small_sample;
+ }
+ found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
+ found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
+
+ // write out tokens, start and size of sample
+ // overwrite the string start position with the token start position
+ out_samples_begin[i] = out_tokens.size();
+ out_samples_size[i] = (size_t) n_tokens;
+ out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
+ } else {
+ out_samples_begin[i] = out_tokens.size();
+ out_samples_size[i] = 0;
+ }
+
+ }
+ if (found_too_big_sample > 0) {
+ printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
+ __func__, found_too_big_sample, found_max_sample_size, context_length);
+ }
+
+ if (found_too_small_sample > 0) {
+ printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
+ __func__, found_too_small_sample, found_min_sample_size, context_length);
+ }
+
+ if (found_empty_sample) {
+ printf("%s: warning: found %zu empty samples.\n",
+ __func__, found_empty_sample);
+ }
+ }
+ printf("%s: total number of samples: %zu\n",
+ __func__, out_samples_begin.size());
+
+ GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
+
+ return out_tokens.size();
+}
+
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
+ std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
+ return replace_str(filename, pattern_it, sit.c_str());
+}
+
+struct train_params_common get_default_train_params_common() {
+ struct train_params_common params;
+ params.fn_train_data = "shakespeare.txt";
+ params.fn_checkpoint_in = "checkpoint.gguf";
+ params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
+ params.pattern_fn_it = "ITERATION";
+ params.fn_latest = "LATEST";
+
+ params.print_usage = false;
+
+ params.save_every = 10;
+
+ params.seed = -1;
+
+ params.n_ctx = 128;
+ params.n_threads = 6;
+ params.n_batch = 8;
+ params.n_gradient_accumulation = 1;
+ params.n_epochs = -1;
+
+ params.custom_n_ctx = false;
+
+ params.use_flash = true;
+ params.use_checkpointing = true;
+
+ params.sample_start = "";
+ params.include_sample_start = false;
+ params.escape = false;
+ params.overlapping_samples = false;
+ params.fill_with_next_samples = false;
+ params.separate_with_eos = false;
+ params.separate_with_bos = true;
+ params.sample_random_offsets = false;
+ params.force_reshuffle = false;
+
+ params.opt_past = 0;
+ params.opt_delta = 1e-5f;
+ params.opt_max_no_improvement = 0;
+
+ params.warmup = 100;
+ params.cos_decay_steps = 1000;
+ params.cos_decay_restart = 1.1f;
+ params.cos_decay_min = 0.1f;
+ params.enable_restart = false;
+
+ params.adam_n_iter = 256;
+ params.adam_alpha = 1e-3f;
+ params.adam_min_alpha = 0;
+ params.adam_decay = 1e-1f;
+ params.adam_decay_min_ndim = 2;
+ params.adam_beta1 = 0.9f;
+ params.adam_beta2 = 0.999f;
+ params.adam_gclip = 1.0f;
+ params.adam_eps_f = 0.0f;
+ return params;
+}
+
+void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
+ // fprintf(stderr, "usage: %s [options]\n", argv[0]);
+ // fprintf(stderr, "\n");
+ // fprintf(stderr, "options:\n");
+ // fprintf(stderr, " -h, --help show this help message and exit\n");
+ fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
+ fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+ fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+ fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
+ fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
+ fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
+ fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
+ fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
+ fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
+ fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
+ fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
+ fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n");
+ fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+ fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
+ fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
+ fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
+ fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
+ fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
+ fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
+ fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
+ fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
+ fprintf(stderr, " --no-flash Don't use flash attention \n");
+ fprintf(stderr, " --use-flash Use flash attention (default)\n");
+ fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
+ fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
+ fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
+ fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+ fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+ fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
+ fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+ fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+ fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+ fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+ fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+ fprintf(stderr, " --epochs N Maximum number epochs to process. (default %d)\n", params->n_epochs);
+ fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+ fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
+ fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
+ fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+ fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
+ fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+ fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+ fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+ fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+ fprintf(stderr, "\n");
+}
+
+bool consume_common_train_arg(
+ int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
+) {
+ int& i = *idx;
+ std::string arg = argv[i];
+ const std::string arg_prefix = "--";
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+ std::replace(arg.begin(), arg.end(), '_', '-');
+ }
+ if (arg == "--train-data") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->fn_train_data = argv[i];
+ } else if (arg == "--checkpoint-in") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->fn_checkpoint_in = argv[i];
+ } else if (arg == "--checkpoint-out") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->fn_checkpoint_out = argv[i];
+ } else if (arg == "--pattern-fn-it") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->pattern_fn_it = argv[i];
+ } else if (arg == "--fn-latest") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->fn_latest = argv[i];
+ } else if (arg == "--save-every") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->save_every = std::stoi(argv[i]);
+ } else if (arg == "-s" || arg == "--seed") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->seed = std::stoi(argv[i]);
+ } else if (arg == "-c" || arg == "--ctx") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->n_ctx = std::stoi(argv[i]);
+ params->custom_n_ctx = true;
+ } else if (arg == "-t" || arg == "--threads") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->n_threads = std::stoi(argv[i]);
+ } else if (arg == "-b" || arg == "--batch") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->n_batch = std::stoi(argv[i]);
+ } else if (arg == "--grad-acc") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
+ } else if (arg == "--sample-start") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->sample_start = std::string(argv[i]);
+ } else if (arg == "--escape") {
+ params->escape = true;
+ } else if (arg == "--include-sample-start") {
+ params->include_sample_start = true;
+ } else if (arg == "--overlapping-samples") {
+ params->overlapping_samples = true;
+ } else if (arg == "--fill-with-next-samples") {
+ params->fill_with_next_samples = true;
+ } else if (arg == "--separate-with-eos") {
+ params->separate_with_eos = true;
+ } else if (arg == "--separate-with-bos") {
+ params->separate_with_bos = true;
+ } else if (arg == "--no-separate-with-eos") {
+ params->separate_with_eos = false;
+ } else if (arg == "--no-separate-with-bos") {
+ params->separate_with_bos = false;
+ } else if (arg == "--sample-random-offsets") {
+ params->sample_random_offsets = true;
+ } else if (arg == "--force-reshuffle") {
+ params->force_reshuffle = true;
+ } else if (arg == "--no-flash") {
+ params->use_flash = false;
+ } else if (arg == "--use-flash") {
+ params->use_flash = true;
+ } else if (arg == "--no-checkpointing") {
+ params->use_checkpointing = false;
+ } else if (arg == "--use-checkpointing") {
+ params->use_checkpointing = true;
+ } else if (arg == "--warmup") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->warmup = std::stoi(argv[i]);
+ } else if (arg == "--cos-decay-steps") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->cos_decay_steps = std::stoi(argv[i]);
+ } else if (arg == "--cos-decay-restart") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->cos_decay_restart = std::stof(argv[i]);
+ } else if (arg == "--cos-decay-min") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->cos_decay_min = std::stof(argv[i]);
+ } else if (arg == "--enable-restart") {
+ params->enable_restart = true;
+ } else if (arg == "--disable-restart") {
+ params->enable_restart = false;
+ } else if (arg == "--opt-past") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->opt_past = std::stoi(argv[i]);
+ } else if (arg == "--opt-delta") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->opt_delta = std::stof(argv[i]);
+ } else if (arg == "--opt-max-no-improvement") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->opt_max_no_improvement = std::stoi(argv[i]);
+ } else if (arg == "--adam-epsf") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_eps_f = std::stof(argv[i]);
+ } else if (arg == "--epochs") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->n_epochs = std::stoi(argv[i]);
+ } else if (arg == "--adam-iter") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_n_iter = std::stoi(argv[i]);
+ } else if (arg == "--adam-alpha") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_alpha = std::stof(argv[i]);
+ } else if (arg == "--adam-min-alpha") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_min_alpha = std::stof(argv[i]);
+ } else if (arg == "--adam-decay") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_decay = std::stof(argv[i]);
+ } else if (arg == "--adam-decay-min-ndim") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_decay_min_ndim = std::stoi(argv[i]);
+ } else if (arg == "--adam-beta1") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_beta1 = std::stof(argv[i]);
+ } else if (arg == "--adam-beta2") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_beta2 = std::stof(argv[i]);
+ } else if (arg == "--adam-gclip") {
+ if (++i >= argc) {
+ *invalid_param = true;
+ return true;
+ }
+ params->adam_gclip = std::stof(argv[i]);
+ } else if (arg == "-h" || arg == "--help") {
+ params->print_usage = true;
+ return true;
+ } else {
+ return false;
+ }
+ return true;
+}
+
+void finish_processing_train_args(struct train_params_common * params) {
+ if (params->escape) {
+ process_escapes(params->sample_start);
+ }
+}
+
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
+ struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
+ struct train_params_common * params = data->params;
+ struct train_state * train = data->train;
+ struct ggml_opt_context * opt = train->opt;
+ int n_batch = params->n_batch;
+ int n_ctx = params->n_ctx;
+
+ if (accum_step == 0) {
+ // time measurement
+ int64_t now = ggml_time_ms();
+ if (now > data->last_time && opt->iter > data->first_iter) {
+ double dt = (double) (now - data->last_time);
+ if (data->millis_per_iter == 0.0) {
+ data->millis_per_iter = dt;
+ } else {
+ const double gain = 0.7;
+ data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
+ }
+ }
+
+ double remaining_millis = 0.0;
+ if (data->millis_per_iter > 0.0) {
+ const int n_iter = params->adam_n_iter;
+ const int done_iter = opt->iter - data->first_iter;
+ const int remaining_iter = n_iter - done_iter;
+ remaining_millis = remaining_iter * data->millis_per_iter;
+ }
+
+ // file saving
+ const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+ if (save_now) {
+ int new_iters = opt->iter - data->last_save_iter;
+ train->train_its += new_iters;
+ train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
+
+ if (data->save_cb) {
+ data->save_cb(data->save_data, train);
+ }
+
+ data->last_save_iter = opt->iter;
+ }
+
+ // exclude file saving from time measurement, by measuring last_time after saving
+ data->last_time = ggml_time_ms();
+
+ *sched = learning_schedule(
+ opt->iter,
+ params->warmup,
+ params->cos_decay_steps,
+ params->adam_alpha,
+ params->adam_min_alpha,
+ params->cos_decay_min,
+ params->cos_decay_restart,
+ params->enable_restart);
+
+ int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+ if (impr_plot > 0) impr_plot = 0;
+ if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
+ printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
+ __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
+ *sched, opt->loss_after);
+
+
+ if (data->millis_per_iter > 0) {
+ printf(" dt=");
+ print_duration(data->millis_per_iter);
+ printf(" eta=");
+ print_duration(remaining_millis);
+ }
+
+ float improvement = opt->loss_before - opt->loss_after;
+ const float plot_scale = 10.0f;
+ int bar_len = (int)(1 + improvement*plot_scale + 0.5);
+ printf(" |");
+ for (int i=0; i<bar_len; ++i) {
+ printf("-");
+ }
+ printf(">");
+ printf("\n");
+ }
+
+ int64_t used_samples = get_example_targets_batch(
+ data->lctx,
+ data->tokens_input,
+ data->target_probs,
+ train->shuffle_next_sample,
+ data->shuffled_samples_offs,
+ data->shuffled_samples_begin,
+ data->shuffled_samples_size,
+ data->samples_count,
+ data->tokens_data,
+ data->tokens_size,
+ params->separate_with_eos,
+ params->separate_with_bos,
+ params->fill_with_next_samples,
+ params->sample_random_offsets);
+
+ train->train_samples += used_samples;
+ train->shuffle_next_sample += used_samples;
+
+ if (train->shuffle_next_sample >= train->shuffle_sample_count) {
+ ++train->train_epochs;
+ printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
+ // note: we may have used some samples from the current shuffling more than once
+ train->shuffle_rng_state_current = train->shuffle_rng_state_next;
+ train->shuffle_rng_state_next = shuffle_samples(
+ train->shuffle_rng_state_current,
+ data->shuffled_samples_offs,
+ data->shuffled_samples_begin,
+ data->shuffled_samples_size,
+ data->samples_begin,
+ data->samples_size,
+ data->samples_count);
+ train->shuffle_next_sample = 0;
+ }
+
+ const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
+ if (last_epoch_reached) {
+ // allow optimization iteration at last epoch to be completed before canceling
+ if (data->iter_at_last_epoch < 0) {
+ data->iter_at_last_epoch = opt->iter;
+ } else if (opt->iter > data->iter_at_last_epoch) {
+ *cancel = true;
+ }
+ }
+}
--- /dev/null
+// Various helper functions and utilities for training
+
+#pragma once
+
+#include <string>
+#include <random>
+#include <vector>
+
+#include "ggml.h"
+#include "llama.h"
+
+typedef std::string mt19937_state;
+
+struct train_state {
+ struct ggml_opt_context * opt;
+
+ uint64_t train_its;
+ uint64_t train_samples;
+ uint64_t train_tokens;
+ uint64_t train_epochs;
+
+ size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
+ mt19937_state shuffle_rng_state_current;
+ mt19937_state shuffle_rng_state_next;
+ size_t shuffle_sample_count;
+ size_t shuffle_next_sample;
+};
+
+struct train_params_common {
+ const char * fn_train_data;
+ const char * fn_checkpoint_in;
+ const char * fn_checkpoint_out;
+ const char * pattern_fn_it;
+ const char * fn_latest;
+
+ bool print_usage;
+
+ int save_every;
+
+ uint32_t seed;
+
+ int n_ctx;
+ int n_threads;
+ int n_batch;
+ int n_gradient_accumulation;
+ int n_epochs;
+
+ bool custom_n_ctx;
+
+ bool use_flash;
+ bool use_checkpointing;
+
+ std::string sample_start;
+ bool include_sample_start;
+ bool escape;
+ bool overlapping_samples;
+ bool fill_with_next_samples;
+ bool separate_with_eos;
+ bool separate_with_bos;
+ bool sample_random_offsets;
+
+ bool force_reshuffle;
+
+ int warmup;
+ int cos_decay_steps;
+ float cos_decay_restart;
+ float cos_decay_min;
+ bool enable_restart;
+
+ int opt_past;
+ float opt_delta;
+ int opt_max_no_improvement;
+
+ int adam_n_iter;
+ float adam_alpha;
+ float adam_min_alpha;
+ float adam_decay;
+ int adam_decay_min_ndim;
+ float adam_beta1;
+ float adam_beta2;
+ float adam_gclip;
+ float adam_eps_f;
+};
+
+typedef void (*save_train_files_callback)(void * data, struct train_state * train);
+
+struct train_opt_callback_data {
+ struct train_params_common * params;
+ struct train_state * train;
+ save_train_files_callback save_cb;
+ void * save_data;
+ struct llama_context * lctx;
+ int last_save_iter;
+ llama_token * tokens_data;
+ size_t tokens_size;
+ size_t * samples_begin;
+ size_t * samples_size;
+ size_t * shuffled_samples_offs;
+ size_t * shuffled_samples_begin;
+ size_t * shuffled_samples_size;
+ size_t samples_count;
+ struct ggml_tensor * tokens_input;
+ struct ggml_tensor * target_probs;
+ int first_iter;
+ int first_epoch;
+ int iter_at_last_epoch;
+ int64_t last_time;
+ double millis_per_iter;
+};
+
+struct train_state * init_train_state();
+void free_train_state(struct train_state * state);
+
+struct train_params_common get_default_train_params_common();
+void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
+
+bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
+void finish_processing_train_args(struct train_params_common * params);
+
+struct random_normal_distribution;
+struct random_uniform_distribution;
+
+struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
+struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
+
+void free_random_normal_distribution (struct random_normal_distribution * rnd);
+void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
+
+struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
+struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
+
+// generate random float in interval [0,1)
+float frand();
+float frand_normal (struct random_normal_distribution * rnd);
+float frand_uniform(struct random_uniform_distribution * rnd);
+
+int clamp (const int v, const int min, const int max);
+float fclamp(const float v, const float min, const float max);
+
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
+
+size_t tokenize_file(
+ struct llama_context * lctx,
+ const char * filename,
+ const std::string & sample_start,
+ bool include_sample_start,
+ bool overlapping_samples,
+ unsigned context_length,
+ std::vector<llama_token> & out_tokens,
+ std::vector<size_t> & out_samples_begin,
+ std::vector<size_t> & out_samples_size);
+
+int64_t get_example_targets_batch(
+ struct llama_context * lctx,
+ struct ggml_tensor * tokens_input,
+ struct ggml_tensor * target_probs,
+ int64_t example_id,
+ const size_t * samples_offs,
+ const size_t * samples_begin,
+ const size_t * samples_size,
+ size_t samples_count,
+ const llama_token * train_data,
+ size_t n_train_data,
+ bool separate_with_eos,
+ bool separate_with_bos,
+ bool fill_with_next_samples,
+ bool sample_random_offsets);
+
+
+void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
+mt19937_state mt19937_get_state(const std::mt19937& rng);
+mt19937_state mt19937_seed_to_state(unsigned seed);
+
+mt19937_state shuffle_samples(
+ const mt19937_state & rng_state,
+ size_t * shuffled_offs,
+ size_t * shuffled_begins,
+ size_t * shuffled_sizes,
+ const size_t * begins,
+ const size_t * sizes,
+ size_t count);
+
+size_t hash_combine(size_t h1, size_t h2);
+
+size_t compute_samples_hash(
+ const char* fn,
+ const size_t* samples_begin,
+ const size_t* samples_size,
+ size_t sample_count);
+
+
+std::string replace_str(const char * s, const char * needle, const char * replacement);
+
+void print_duration(double milliseconds);
+
+float cosine_decay(
+ int64_t step,
+ int64_t decay_steps,
+ float minimum);
+
+float cosine_decay_restart(
+ int64_t step,
+ int64_t decay_steps,
+ float minimum,
+ float restart_step_mult);
+
+float learning_schedule(
+ int64_t step,
+ int64_t warmup_steps,
+ int64_t decay_steps,
+ float learning_rate,
+ float overall_minimum,
+ float cos_decay_minimum,
+ float cos_decay_restart_step_mult,
+ bool enable_restart);
+
+void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
+
+void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
+void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
+
+bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
+void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
+
+std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
+
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);
add_subdirectory(benchmark)
add_subdirectory(baby-llama)
add_subdirectory(train-text-from-scratch)
+ add_subdirectory(finetune)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple)
add_subdirectory(batched)
if (LLAMA_BUILD_SERVER)
add_subdirectory(server)
endif()
+ add_subdirectory(export-lora)
endif()
#include "ggml.h"
+#include "train.h"
#include <vector>
#include <cassert>
#include <random>
constexpr float rms_norm_eps = 5e-6f;
#endif
-static float frand() {
- return (float)rand()/(float)RAND_MAX;
-}
-
-struct random_normal_distribution {
- std::mt19937 gen;
- std::normal_distribution<float> nd;
- float min;
- float max;
-};
-
-static void init_random_normal_distribution(
- struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max
-) {
- rnd->gen = std::mt19937(seed);
- rnd->nd = std::normal_distribution<float>{mean, std};
- rnd->min = min;
- rnd->max = max;
-}
-
-static float frand_normal(struct random_normal_distribution * rnd) {
- const float r = rnd->nd(rnd->gen);
- return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
-}
-
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
return tensor;
}
-static struct ggml_tensor * randomize_tensor_normal(
- struct ggml_tensor * tensor, int ndims, const int64_t ne[], struct random_normal_distribution * rnd
-) {
- float scale = 1.0; // xavier
- switch (ndims) {
- case 1:
- scale /= sqrtf(ne[0]);
- for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i0] = scale * frand_normal(rnd);
- }
- break;
- case 2:
- scale /= sqrtf(ne[0]+ne[1]);
- for (int i1 = 0; i1 < ne[1]; i1++) {
- for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd);
- }
- }
- break;
- case 3:
- scale /= sqrtf(ne[0]+ne[1]);
- for (int i2 = 0; i2 < ne[2]; i2++) {
- for (int i1 = 0; i1 < ne[1]; i1++) {
- for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
- }
- }
- }
- break;
- case 4:
- scale /= sqrtf(ne[0]+ne[1]);
- for (int i3 = 0; i3 < ne[3]; i3++) {
- for (int i2 = 0; i2 < ne[2]; i2++) {
- for (int i1 = 0; i1 < ne[1]; i1++) {
- for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
- }
- }
- }
- }
- break;
- default:
- assert(false);
- };
-
- return tensor;
-}
-
struct llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input?
const uint32_t n_layer = hparams.n_layer;
- struct random_normal_distribution rnd;
- init_random_normal_distribution(&rnd, seed, mean, std, min, max);
- randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
- randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
- randomize_tensor_normal(model->output, model->output->n_dims, model->output->ne, &rnd);
+ struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+ randomize_tensor_normal(model->tok_embeddings , rnd);
+ randomize_tensor_normal(model->norm , rnd);
+ randomize_tensor_normal(model->output , rnd);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
- randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+ randomize_tensor_normal(layer.attention_norm, rnd);
- randomize_tensor_normal(layer.wq, layer.wq->n_dims, layer.wq->ne, &rnd);
- randomize_tensor_normal(layer.wk, layer.wk->n_dims, layer.wk->ne, &rnd);
- randomize_tensor_normal(layer.wv, layer.wv->n_dims, layer.wv->ne, &rnd);
- randomize_tensor_normal(layer.wo, layer.wo->n_dims, layer.wo->ne, &rnd);
+ randomize_tensor_normal(layer.wq, rnd);
+ randomize_tensor_normal(layer.wk, rnd);
+ randomize_tensor_normal(layer.wv, rnd);
+ randomize_tensor_normal(layer.wo, rnd);
- randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+ randomize_tensor_normal(layer.ffn_norm, rnd);
- randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
- randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
- randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+ randomize_tensor_normal(layer.w1, rnd);
+ randomize_tensor_normal(layer.w2, rnd);
+ randomize_tensor_normal(layer.w3, rnd);
}
+
+ free_random_normal_distribution(rnd);
}
const uint32_t n_layer = hparams.n_layer;
- struct random_normal_distribution rnd;
- init_random_normal_distribution(&rnd, seed, mean, std, min, max);
- randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
- randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
- randomize_tensor_normal(model->outputa, model->outputa->n_dims, model->outputa->ne, &rnd);
- randomize_tensor_normal(model->outputb, model->outputb->n_dims, model->outputb->ne, &rnd);
+ struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+ randomize_tensor_normal(model->tok_embeddings, rnd);
+ randomize_tensor_normal(model->norm , rnd);
+ randomize_tensor_normal(model->outputa , rnd);
+ randomize_tensor_normal(model->outputb , rnd);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
- randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
-
- randomize_tensor_normal(layer.wqa, layer.wqa->n_dims, layer.wqa->ne, &rnd);
- randomize_tensor_normal(layer.wqb, layer.wqb->n_dims, layer.wqb->ne, &rnd);
- randomize_tensor_normal(layer.wka, layer.wka->n_dims, layer.wka->ne, &rnd);
- randomize_tensor_normal(layer.wkb, layer.wkb->n_dims, layer.wkb->ne, &rnd);
- randomize_tensor_normal(layer.wva, layer.wva->n_dims, layer.wva->ne, &rnd);
- randomize_tensor_normal(layer.wvb, layer.wvb->n_dims, layer.wvb->ne, &rnd);
- randomize_tensor_normal(layer.woa, layer.woa->n_dims, layer.woa->ne, &rnd);
- randomize_tensor_normal(layer.wob, layer.wob->n_dims, layer.wob->ne, &rnd);
-
- randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
-
- randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
- randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
- randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+ randomize_tensor_normal(layer.attention_norm, rnd);
+
+ randomize_tensor_normal(layer.wqa, rnd);
+ randomize_tensor_normal(layer.wqb, rnd);
+ randomize_tensor_normal(layer.wka, rnd);
+ randomize_tensor_normal(layer.wkb, rnd);
+ randomize_tensor_normal(layer.wva, rnd);
+ randomize_tensor_normal(layer.wvb, rnd);
+ randomize_tensor_normal(layer.woa, rnd);
+ randomize_tensor_normal(layer.wob, rnd);
+
+ randomize_tensor_normal(layer.ffn_norm, rnd);
+
+ randomize_tensor_normal(layer.w1, rnd);
+ randomize_tensor_normal(layer.w2, rnd);
+ randomize_tensor_normal(layer.w3, rnd);
}
+
+ free_random_normal_distribution(rnd);
}
static bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
return inpL;
}
-static void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
- GGML_ASSERT(tensor->n_dims == 1);
- GGML_ASSERT(tensor->ne[0] == ne0);
-}
-
-static void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
- GGML_ASSERT(tensor->n_dims == 2);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
-}
-
-static void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
- GGML_ASSERT(tensor->n_dims == 3);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
- GGML_ASSERT(tensor->ne[2] == ne2);
-}
-
-static void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
- GGML_ASSERT(tensor->n_dims == 4);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
- GGML_ASSERT(tensor->ne[2] == ne2);
- GGML_ASSERT(tensor->ne[3] == ne3);
-}
-
static struct ggml_tensor * forward_batch(
struct llama_model * model,
struct llama_kv_cache * cache,
--- /dev/null
+set(TARGET export-lora)
+add_executable(${TARGET} export-lora.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
--- /dev/null
+# export-lora
+
+Apply LORA adapters to base model and export the resulting model.
+
+```
+usage: export-lora [options]
+
+options:
+ -h, --help show this help message and exit
+ -m FNAME, --model-base FNAME model path from which to load base model (default '')
+ -o FNAME, --model-out FNAME path to save exported model (default '')
+ -l FNAME, --lora FNAME apply LoRA adapter
+ -s FNAME S, --lora-scaled FNAME S apply LoRA adapter with user defined scaling S
+ -t N, --threads N number of threads to use during computation (default: 4)
+```
+
+For example:
+
+```bash
+./bin/export-lora \
+ -m open-llama-3b-v2-q8_0.gguf \
+ -o open-llama-3b-v2-q8_0-english2tokipona-chat.gguf \
+ -l lora-open-llama-3b-v2-q8_0-english2tokipona-chat-LATEST.bin
+```
+
+Multiple LORA adapters can be applied by passing multiple `-l FN` or `-s FN S` command line parameters.
--- /dev/null
+
+#include "common.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#include <vector>
+#include <string>
+#include <thread>
+
+static const size_t tensor_alignment = 32;
+
+struct lora_info {
+ std::string filename;
+ float scale;
+};
+
+struct export_lora_params {
+ std::string fn_model_base;
+ std::string fn_model_out;
+ std::vector<struct lora_info> lora;
+ int n_threads;
+};
+
+struct lora_data {
+ struct lora_info info;
+ std::vector<uint8_t> data;
+ struct ggml_context * ctx;
+
+ uint32_t lora_r;
+ uint32_t lora_alpha;
+};
+
+struct llama_file {
+ // use FILE * so we don't have to re-open the file to mmap
+ FILE * fp;
+ size_t size;
+
+ llama_file(const char * fname, const char * mode) {
+ fp = std::fopen(fname, mode);
+ if (fp == NULL) {
+ size = 0;
+ } else {
+ seek(0, SEEK_END);
+ size = tell();
+ seek(0, SEEK_SET);
+ }
+ }
+
+ size_t tell() const {
+#ifdef _WIN32
+ __int64 ret = _ftelli64(fp);
+#else
+ long ret = std::ftell(fp);
+#endif
+ GGML_ASSERT(ret != -1); // this really shouldn't fail
+ return (size_t) ret;
+ }
+
+ void seek(size_t offset, int whence) {
+#ifdef _WIN32
+ int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+ int ret = std::fseek(fp, (long) offset, whence);
+#endif
+ GGML_ASSERT(ret == 0); // same
+ }
+
+ void read_raw(void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ std::size_t ret = std::fread(ptr, size, 1, fp);
+ if (ferror(fp)) {
+ die_fmt("read error: %s", strerror(errno));
+ }
+ if (ret != 1) {
+ die("unexpectedly reached end of file");
+ }
+ }
+
+ std::uint32_t read_u32() {
+ std::uint32_t ret;
+ read_raw(&ret, sizeof(ret));
+ return ret;
+ }
+
+ std::string read_string(std::uint32_t len) {
+ std::vector<char> chars(len);
+ read_raw(chars.data(), len);
+ return std::string(chars.data(), len);
+ }
+
+ void write_raw(const void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ size_t ret = std::fwrite(ptr, size, 1, fp);
+ if (ret != 1) {
+ die_fmt("write error: %s", strerror(errno));
+ }
+ }
+
+ void write_u32(std::uint32_t val) {
+ write_raw(&val, sizeof(val));
+ }
+
+ bool eof() {
+ return tell() >= size;
+ }
+
+ ~llama_file() {
+ if (fp) {
+ std::fclose(fp);
+ }
+ }
+};
+
+static struct export_lora_params get_default_export_lora_params() {
+ struct export_lora_params result;
+ result.fn_model_base = "";
+ result.fn_model_out = "";
+ result.n_threads = GGML_DEFAULT_N_THREADS;
+ return result;
+}
+
+static void export_lora_print_usage(int /*argc*/, char ** argv, const struct export_lora_params * params) {
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
+ fprintf(stderr, "\n");
+ fprintf(stderr, "options:\n");
+ fprintf(stderr, " -h, --help show this help message and exit\n");
+ fprintf(stderr, " -m FNAME, --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base.c_str());
+ fprintf(stderr, " -o FNAME, --model-out FNAME path to save exported model (default '%s')\n", params->fn_model_out.c_str());
+ fprintf(stderr, " -l FNAME, --lora FNAME apply LoRA adapter\n");
+ fprintf(stderr, " -s FNAME S, --lora-scaled FNAME S apply LoRA adapter with user defined scaling S\n");
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params->n_threads);
+}
+
+static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_params * params) {
+ bool invalid_param = false;
+ std::string arg;
+ struct export_lora_params default_params = get_default_export_lora_params();
+ const std::string arg_prefix = "--";
+
+ for (int i = 1; i < argc; i++) {
+ arg = argv[i];
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+ std::replace(arg.begin(), arg.end(), '_', '-');
+ }
+
+ if (arg == "-m" || arg == "--model-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->fn_model_base = argv[i];
+ } else if (arg == "-o" || arg == "--model-out") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->fn_model_out = argv[i];
+ } else if (arg == "-l" || arg == "--lora") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ struct lora_info lora;
+ lora.filename = argv[i];
+ lora.scale = 1.0f;
+ params->lora.push_back(lora);
+ } else if (arg == "-s" || arg == "--lora-scaled") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ struct lora_info lora;
+ lora.filename = argv[i];
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ lora.scale = std::stof(argv[i]);
+ params->lora.push_back(lora);
+ } else if (arg == "-t" || arg == "--threads") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_threads = std::stoi(argv[i]);
+ if (params->n_threads <= 0) {
+ params->n_threads = std::thread::hardware_concurrency();
+ }
+ } else {
+ fprintf(stderr, "error: unknown argument: '%s'\n", arg.c_str());
+ export_lora_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ }
+
+ if (params->fn_model_base == default_params.fn_model_base) {
+ fprintf(stderr, "error: please specify a filename for model-base.\n");
+ export_lora_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ if (params->fn_model_out == default_params.fn_model_out) {
+ fprintf(stderr, "error: please specify a filename for model-out.\n");
+ export_lora_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ if (invalid_param) {
+ fprintf(stderr, "error: invalid parameter for argument: '%s'\n", arg.c_str());
+ export_lora_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ return true;
+}
+
+static void free_lora(struct lora_data * lora) {
+ if (lora->ctx != NULL) {
+ ggml_free(lora->ctx);
+ }
+ delete lora;
+}
+
+static struct lora_data * load_lora(struct lora_info * info) {
+ struct lora_data * result = new struct lora_data;
+ result->info = *info;
+ result->ctx = NULL;
+ result->lora_r = 1;
+ result->lora_alpha = 1;
+
+ struct llama_file file(info->filename.c_str(), "rb");
+ if (file.fp == NULL) {
+ fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
+ info->filename.c_str());
+ free_lora(result);
+ return NULL;
+ }
+
+ struct ggml_init_params params_ggml;
+ params_ggml.mem_size = ggml_tensor_overhead() * GGML_MAX_NODES;
+ params_ggml.mem_buffer = NULL;
+ params_ggml.no_alloc = true;
+ result->ctx = ggml_init(params_ggml);
+
+ uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
+ uint32_t magic = file.read_u32();
+ if (magic != LLAMA_FILE_MAGIC_LORA) {
+ die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
+ }
+ uint32_t version = file.read_u32();
+ if (version != 1) {
+ die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
+ }
+ result->lora_r = file.read_u32();
+ result->lora_alpha = file.read_u32();
+ // read tensor infos from file
+ std::vector<char> name_buf;
+ std::vector<struct ggml_tensor *> tensors;
+ std::vector<size_t> tensors_offset;
+ size_t total_nbytes_pad = 0;
+ while(!file.eof()) {
+ int64_t ne[4] = {1,1,1,1};
+ uint32_t n_dims = file.read_u32();
+ uint32_t namelen = file.read_u32();
+ uint32_t type = file.read_u32();
+ for (uint32_t k = 0; k < n_dims; ++k) {
+ ne[k] = (int64_t)file.read_u32();
+ }
+ name_buf.clear();
+ name_buf.resize(namelen + 1, '\0');
+ file.read_raw(name_buf.data(), namelen);
+ file.seek((0-file.tell()) & 31, SEEK_CUR);
+ size_t offset = file.tell();
+ struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
+ ggml_set_name(tensor, name_buf.data());
+ size_t nbytes = ggml_nbytes(tensor);
+ size_t nbytes_pad = ggml_nbytes_pad(tensor);
+ total_nbytes_pad += nbytes_pad;
+ tensors.push_back(tensor);
+ tensors_offset.push_back(offset);
+ file.seek(nbytes, SEEK_CUR);
+ }
+ // read tensor data
+ result->data.resize(total_nbytes_pad);
+ size_t data_offset = 0;
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ struct ggml_tensor * tensor = tensors[i];
+ size_t offset = tensors_offset[i];
+ size_t nbytes = ggml_nbytes(tensor);
+ size_t nbytes_pad = ggml_nbytes_pad(tensor);
+ file.seek(offset, SEEK_SET);
+ tensor->data = result->data.data() + data_offset;
+ file.read_raw(tensor->data, nbytes);
+ data_offset += nbytes_pad;
+ }
+ return result;
+}
+
+
+static struct ggml_cgraph * build_graph_lora(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor,
+ struct ggml_tensor * lora_a,
+ struct ggml_tensor * lora_b,
+ float scaling
+) {
+ struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
+ if (scaling != 1.0f) {
+ ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
+ }
+ struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);
+
+ struct ggml_cgraph * gf = ggml_new_graph(ctx);
+ ggml_build_forward_expand (gf, res);
+ return gf;
+}
+
+static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int n_threads) {
+ if (lora->ctx == NULL) {
+ return false;
+ }
+ std::string name = ggml_get_name(tensor);
+ std::string name_a = name + std::string(".loraA");
+ std::string name_b = name + std::string(".loraB");
+ struct ggml_tensor * lora_a = ggml_get_tensor(lora->ctx, name_a.c_str());
+ struct ggml_tensor * lora_b = ggml_get_tensor(lora->ctx, name_b.c_str());
+ if (lora_a == NULL || lora_b == NULL) {
+ return false;
+ }
+
+ float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;
+
+ struct ggml_init_params params;
+ params.mem_size = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
+ params.mem_buffer = NULL;
+ params.no_alloc = true;
+ struct ggml_context * ctx = NULL;
+ struct ggml_allocr * alloc = NULL;
+ struct ggml_cgraph * gf = NULL;
+
+ ctx = ggml_init(params);
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
+ size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
+ ggml_allocr_free(alloc);
+ ggml_free(ctx);
+
+ static std::vector<uint8_t> data_compute;
+ data_compute.resize(alloc_size + tensor_alignment);
+
+ ctx = ggml_init(params);
+ alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
+ gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
+ ggml_allocr_alloc_graph(alloc, gf);
+ ggml_allocr_free(alloc);
+
+ struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
+ static std::vector<uint8_t> data_work;
+ data_work.resize(cplan.work_size);
+ cplan.work_data = data_work.data();
+
+ ggml_graph_compute(gf, &cplan);
+
+ ggml_free(ctx);
+ return true;
+}
+
+static void export_lora(struct export_lora_params * params) {
+ // load all loras
+ std::vector<struct lora_data *> loras;
+ for (size_t i = 0; i < params->lora.size(); ++i) {
+ struct lora_data * lora = load_lora(¶ms->lora[i]);
+ if (lora != NULL) {
+ loras.push_back(lora);
+ }
+ }
+ if (loras.size() == 0) {
+ fprintf(stderr, "warning: no lora adapters will be applied.\n");
+ }
+
+ // open input file
+ struct llama_file fin(params->fn_model_base.c_str(), "rb");
+ if (!fin.fp) {
+ die_fmt("Could not open file '%s'\n", params->fn_model_base.c_str());
+ }
+
+ // open base model gguf, read tensors without their data
+ struct ggml_context * ctx_in;
+ struct gguf_init_params params_gguf;
+ params_gguf.no_alloc = true;
+ params_gguf.ctx = &ctx_in;
+ struct gguf_context * gguf_in = gguf_init_from_file(params->fn_model_base.c_str(), params_gguf);
+
+ // create new gguf
+ struct gguf_context * gguf_out = gguf_init_empty();
+
+ // copy meta data from base model: kv and tensors
+ gguf_set_kv(gguf_out, gguf_in);
+ int n_tensors = gguf_get_n_tensors(gguf_in);
+ for (int i=0; i < n_tensors; ++i) {
+ const char * name = gguf_get_tensor_name(gguf_in, i);
+ struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
+ gguf_add_tensor(gguf_out, tensor);
+ }
+
+ // create output file
+ struct llama_file fout(params->fn_model_out.c_str(), "wb");
+ if (!fout.fp) {
+ die_fmt("Could not create file '%s'\n", params->fn_model_out.c_str());
+ }
+
+ // write gguf meta data
+ std::vector<uint8_t> meta;
+ meta.resize(gguf_get_meta_size(gguf_out));
+ gguf_get_meta_data(gguf_out, meta.data());
+ fout.write_raw(meta.data(), meta.size());
+
+ std::vector<uint8_t> data;
+ std::vector<uint8_t> padding;
+ for (int i=0; i < n_tensors; ++i) {
+ const char * name = gguf_get_tensor_name(gguf_in, i);
+ struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
+
+ // read tensor data
+ data.resize(ggml_nbytes(tensor));
+ tensor->data = data.data();
+ size_t offset = gguf_get_tensor_offset(gguf_in, i);
+ fin.seek(offset + meta.size(), SEEK_SET);
+ fin.read_raw(data.data(), data.size());
+
+ // apply all loras
+ for (size_t k = 0; k < loras.size(); ++k) {
+ apply_lora(tensor, loras[k], params->n_threads);
+ }
+
+ // write tensor data + padding
+ padding.clear();
+ padding.resize(GGML_PAD(data.size(), gguf_get_alignment(gguf_out)) - data.size(), 0);
+
+ GGML_ASSERT(fout.tell() == offset + meta.size());
+ // fout.seek(offset + meta.size(), SEEK_SET);
+ fout.write_raw(data.data(), data.size());
+ fout.write_raw(padding.data(), padding.size());
+
+ if (i % 2 == 0) {
+ printf(".");
+ }
+ }
+ printf("\n");
+
+ // close gguf
+ gguf_free(gguf_out);
+ gguf_free(gguf_in);
+
+ // free loras
+ for (size_t i = 0; i < loras.size(); ++i) {
+ free_lora(loras[i]);
+ }
+}
+
+int main(int argc, char ** argv) {
+ struct export_lora_params params = get_default_export_lora_params();
+
+ if (!export_lora_params_parse(argc, argv, ¶ms)) {
+ return 1;
+ }
+
+ export_lora(¶ms);
+
+ return 0;
+}
--- /dev/null
+set(TARGET finetune)
+add_executable(${TARGET} finetune.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
--- /dev/null
+# finetune
+
+Basic usage instructions:
+
+```bash
+# get training data
+wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/shakespeare.txt
+
+# finetune LORA adapter
+./bin/finetune \
+ --model-base open-llama-3b-v2-q8_0.gguf \
+ --checkpoint-in chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
+ --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
+ --lora-out lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.bin \
+ --train-data "shakespeare.txt" \
+ --save-every 10 \
+ --threads 6 --adam-iter 30 --batch 4 --ctx 64 \
+ --use-checkpointing
+
+# predict
+./bin/main -m open-llama-3b-v2-q8_0.gguf --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+```
+
+Finetune output files will be saved every N iterations (config with `--save-every N`).
+The pattern 'ITERATION' in the output filenames will be replaced with the iteration number and with 'LATEST' for the latest output.
+So in above example after 10 iterations these files will be written:
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-10.gguf
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf
+- lora-open-llama-3b-v2-q8_0-shakespeare-10.bin
+- lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+
+After 10 more iterations:
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-20.gguf
+- chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf
+- lora-open-llama-3b-v2-q8_0-shakespeare-20.bin
+- lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
+
+Checkpoint files (`--checkpoint-in FN`, `--checkpoint-out FN`) store the training process. When the input checkpoint file does not exist, it will begin finetuning a new randomly initialized adapter.
+
+llama.cpp compatible LORA adapters will be saved with filename specified by `--lora-out FN`.
+These LORA adapters can then be used by `main` together with the base model, like in the 'predict' example command above.
+
+In `main` you can also load multiple LORA adapters, which will then be mixed together.
+
+For example if you have two LORA adapters `lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin` and `lora-open-llama-3b-v2-q8_0-bible-LATEST.bin`, you can mix them together like this:
+
+```bash
+./bin/main -m open-llama-3b-v2-q8_0.gguf \
+ --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin \
+ --lora lora-open-llama-3b-v2-q8_0-bible-LATEST.bin
+```
+
+You can change how strong each LORA adapter is applied to the base model by using `--lora-scaled FN SCALE` instead of `--lora FN`.
+
+For example to apply 40% of the 'shakespeare' LORA adapter, 80% of the 'bible' LORA adapter and 100% of yet another one:
+
+```bash
+./bin/main -m open-llama-3b-v2-q8_0.gguf \
+ --lora-scaled lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin 0.4 \
+ --lora-scaled lora-open-llama-3b-v2-q8_0-bible-LATEST.bin 0.8 \
+ --lora lora-open-llama-3b-v2-q8_0-yet-another-one-LATEST.bin
+```
+
+The scale numbers don't need to add up to one, and you can also use numbers creater than 1 to further increase the influence of an adapter. But making the values to big will sometimes result in worse output. Play around to find good values.
+
+Gradient checkpointing reduces the memory requirements by ~50% but increases the runtime.
+If you have enough RAM, you can make finetuning a bit faster by disabling checkpointing with `--no-checkpointing`.
+
+The default LORA rank can be specified with `--lora-r N`.
+The LORA rank can be configured for each model tensor type separately with these command line options:
+
+```bash
+ --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default 4)
+ --rank-att-norm N LORA rank for attention norm tensor (default 1)
+ --rank-ffn-norm N LORA rank for feed-forward norm tensor (default 1)
+ --rank-out-norm N LORA rank for output norm tensor (default 1)
+ --rank-tok-embd N LORA rank for token embeddings tensor (default 4)
+ --rank-out N LORA rank for output tensor (default 4)
+ --rank-wq N LORA rank for wq tensor (default 4)
+ --rank-wk N LORA rank for wk tensor (default 4)
+ --rank-wv N LORA rank for wv tensor (default 4)
+ --rank-wo N LORA rank for wo tensor (default 4)
+ --rank-w1 N LORA rank for w1 tensor (default 4)
+ --rank-w2 N LORA rank for w2 tensor (default 4)
+ --rank-w3 N LORA rank for w3 tensor (default 4)
+```
+
+The LORA rank of 'norm' tensors should always be 1.
+
+To see all available options use `finetune --help`.
--- /dev/null
+#!/usr/bin/env python3
+# finetune checkpoint --> gguf conversion
+
+import argparse
+import gguf
+import os
+import struct
+import sys
+import numpy as np
+from pathlib import Path
+
+# gguf constants
+LLM_KV_OPTIMIZER_TYPE = "optimizer.type"
+LLM_KV_OPTIMIZER_TYPE_ADAM = "adam"
+LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs"
+LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version"
+LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count"
+LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count"
+LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count"
+LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized"
+LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss"
+LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss"
+LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count"
+LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count"
+LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k"
+LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end"
+LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count"
+
+LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments"
+LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments"
+LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values"
+
+LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters"
+LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters"
+LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients"
+LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients"
+LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction"
+LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"
+LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"
+
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE = "training.type"
+LLM_KV_TRAINING_FILE_VERSION = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"
+
+LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd"
+LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm"
+LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output"
+LLM_KV_TRAINING_LORA_RANK_ATTN_NORM = "training.lora.rank.attn_norm"
+LLM_KV_TRAINING_LORA_RANK_ATTN_Q = "training.lora.rank.attn_q"
+LLM_KV_TRAINING_LORA_RANK_ATTN_K = "training.lora.rank.attn_k"
+LLM_KV_TRAINING_LORA_RANK_ATTN_V = "training.lora.rank.attn_v"
+LLM_KV_TRAINING_LORA_RANK_ATTN_OUT = "training.lora.rank.attn_output"
+LLM_KV_TRAINING_LORA_RANK_FFN_NORM = "training.lora.rank.ffn_norm"
+LLM_KV_TRAINING_LORA_RANK_FFN_GATE = "training.lora.rank.ffn_gate"
+LLM_KV_TRAINING_LORA_RANK_FFN_DOWN = "training.lora.rank.ffn_down"
+LLM_KV_TRAINING_LORA_RANK_FFN_UP = "training.lora.rank.ffn_up"
+
+class Tensor:
+ def __init__(self, dtype='f', ne=None):
+ if ne is None:
+ ne = []
+ self.dtype = dtype
+ self.ne = ne
+ self.nbytes = 0
+ if self.dtype == 'f':
+ if len(self.ne) == 0:
+ self.nbytes = 0
+ else:
+ self.nbytes = int(np.product(self.ne)) * 4
+ else:
+ raise ValueError(f"Unhandled data type '{self.dtype}'")
+
+ def load(self, data, offset):
+ nd = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ namelen = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ dtype = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+ assert(nd == len(self.ne))
+ ne = []
+ for d in range(nd):
+ n = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ ne.append(n)
+
+ if tuple(ne) != tuple(self.ne):
+ raise ValueError(f"Tensor.load: Expected number of elements {str(self.ne)} does not match what is read from file {str(ne)}")
+
+ if self.dtype == 'f':
+ assert(dtype == 0)
+ else:
+ raise ValueError(f"Unhandled data type '{self.dtype}'")
+
+ self.name = bytes(data[offset:offset+namelen]); offset += namelen
+ # 32-byte alignment
+ offset += (0 - offset) & 31
+ self.data = data[offset:offset+self.nbytes]
+ offset += self.nbytes
+ return offset
+
+ def max_storage_size(self):
+ result = 0
+ result += 4 # nd
+ result += 4 # namelen
+ result += 4 # dtype
+ result += len(self.ne)*8 # ne
+ result += 48 # name (maximum as of commit 3b5515bbe0e2224425986ba24f1f5d84aa38dce9)
+ result += 31 # 32-byte alignment
+ result += self.nbytes
+ return result
+
+ def save_gguf(self, gguf_writer, name):
+ gguf_writer.add_tensor(
+ name=name,
+ tensor=self.data,
+ raw_shape=np.array(list(reversed(self.ne))),
+ raw_dtype=gguf.GGMLQuantizationType.F32)
+
+class OptimizationContext:
+ def __init__(self):
+ pass
+
+ def load(self, data, offset):
+ self.version = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]
+ offset += 4
+
+ if self.version != 1:
+ raise ValueError('Invalid version of optimization context in checkpoint file')
+
+ self.past = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_m = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.nx = struct.unpack('N', bytes(data[offset:offset + 8]))[0]; offset += 8
+ self.iter = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.just_initialized = bool(struct.unpack('<i', bytes(data[offset:offset + 4]))[0]); offset += 4
+
+ self.adam_m = Tensor('f', [self.nx])
+ self.adam_v = Tensor('f', [self.nx])
+ self.adam_pf = Tensor('f', [self.past] if self.past > 0 else [])
+
+ self.lbfgs_x = Tensor('f', [self.nx])
+ self.lbfgs_xp = Tensor('f', [self.nx])
+ self.lbfgs_g = Tensor('f', [self.nx])
+ self.lbfgs_gp = Tensor('f', [self.nx])
+ self.lbfgs_d = Tensor('f', [self.nx])
+ self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else [])
+ self.lbfgs_lmal = Tensor('f', [self.lbfgs_m])
+ self.lbfgs_lmys = Tensor('f', [self.lbfgs_m])
+ self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m])
+ self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m])
+
+ # forgot to save type in version 1:
+ # guess self.type from number of remaining bytes
+ size_type_0 = 12 + sum([t.max_storage_size() for t in
+ [self.adam_m, self.adam_v]
+ +([self.adam_pf] if (self.past > 0) else [])])
+ size_type_1 = 24 + sum([t.max_storage_size() for t in
+ [self.lbfgs_x, self.lbfgs_xp, self.lbfgs_g,
+ self.lbfgs_gp, self.lbfgs_d, self.lbfgs_pf,
+ self.lbfgs_lmal, self.lbfgs_lmys,
+ self.lbfgs_lms, self.lbfgs_lmy]
+ +([self.lbfgs_pf] if (self.past > 0) else [])])
+ # due to alignment padding the size might not by exact
+ # but the difference in size for both types is significant,
+ # so we can just use whichever is closest
+ remaining = len(data) - offset
+ if abs(remaining - size_type_0) < abs(remaining - size_type_1):
+ self.type = 0
+ else:
+ self.type = 1
+
+ if self.type == 0:
+ offset = self.adam_m.load(data, offset)
+ offset = self.adam_v.load(data, offset)
+ offset = self.adam_pf.load(data,offset)
+
+ self.adam_fx_best = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.adam_fx_prev = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.adam_n_no_improvement = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+ elif self.type == 1:
+ offset = self.lbfgs_x.load(data, offset)
+ offset = self.lbfgs_xp.load(data, offset)
+ offset = self.lbfgs_g.load(data, offset)
+ offset = self.lbfgs_gp.load(data, offset)
+ offset = self.lbfgs_d.load(data, offset)
+ offset = self.lbfgs_pf.load(data, offset)
+ offset = self.lbfgs_lmal.load(data, offset)
+ offset = self.lbfgs_lmys.load(data, offset)
+ offset = self.lbfgs_lms.load(data, offset)
+ offset = self.lbfgs_lmy.load(data, offset)
+
+ self.lbfgs_fx_best = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_step = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_j = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_k = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_end = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.lbfgs_n_no_improvement = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+ else:
+ raise ValueError(f"Invalid optimizer type '{self.type}'")
+
+ return offset
+
+ def save_gguf(self, gguf_writer):
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_FILE_VERSION, 0)
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, self.past)
+ gguf_writer.add_uint64(LLM_KV_OPTIMIZER_PARAMETER_COUNT, self.nx)
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_ITERATION_COUNT, self.iter)
+ gguf_writer.add_bool(LLM_KV_OPTIMIZER_JUST_INITIALIZED, self.just_initialized)
+
+ if self.type == 0:
+ gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM)
+ gguf_writer.add_float32(LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, self.adam_fx_best)
+ gguf_writer.add_float32(LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, self.adam_fx_prev)
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, self.adam_n_no_improvement)
+
+ self.adam_m.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS)
+ self.adam_v.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS)
+ if self.past > 0:
+ self.adam_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES)
+
+ elif self.type == 1:
+ gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS)
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m)
+ gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best)
+ gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step)
+ gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, self.lbfgs_j)
+ gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, self.lbfgs_k)
+ gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end)
+ gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement)
+
+ self.lbfgs_x.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS)
+ self.lbfgs_xp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS)
+ self.lbfgs_g.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS)
+ self.lbfgs_gp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS)
+ self.lbfgs_d.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION)
+ if self.past > 0:
+ self.lbfgs_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES)
+ self.lbfgs_lmal.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA)
+ self.lbfgs_lmys.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS)
+ self.lbfgs_lms.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S)
+ self.lbfgs_lmy.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y)
+ else:
+ raise ValueError('Unknown optimizer type')
+
+class LoraParams:
+ def __init__(self):
+ pass
+
+ def load(self, data, offset):
+ self.n_rank_attention_norm = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_wq = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_wk = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_wv = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_wo = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_ffn_norm = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_w1 = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_w2 = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_w3 = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_tok_embeddings = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_norm = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rank_output = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ return offset
+
+ def save_gguf(self, gguf_writer):
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD, self.n_rank_tok_embeddings)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM, self.n_rank_norm)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_OUTPUT, self.n_rank_output)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_NORM, self.n_rank_attention_norm)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_Q, self.n_rank_wq)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_K, self.n_rank_wk)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_V, self.n_rank_wv)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_ATTN_OUT, self.n_rank_wo)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_NORM, self.n_rank_ffn_norm)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_GATE, self.n_rank_w1)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_DOWN, self.n_rank_w2)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_LORA_RANK_FFN_UP, self.n_rank_w3)
+
+class ModelParams:
+ def __init__(self, n_ff = None):
+ self.n_ff = n_ff
+
+ def load(self, data, offset):
+ self.n_vocab = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_embd = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_mult = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_head = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_layer = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.n_rot = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ return offset
+
+ def get_n_ff(self):
+ if self.n_ff is None:
+ # struct my_llama_model::get_n_ff in train-text-from-scratch.cpp commit 3b5515bbe0e2224425986ba24f1f5d84aa38dce9
+ return ((2*(4*self.n_embd)//3 + self.n_mult - 1)//self.n_mult)*self.n_mult
+ else:
+ return self.n_ff
+
+ def save_gguf(self, gguf_writer):
+ # self.n_vocab not saved
+ gguf_writer.add_embedding_length(self.n_embd)
+ gguf_writer.add_head_count(self.n_head)
+ gguf_writer.add_block_count(self.n_layer)
+ gguf_writer.add_rope_dimension_count(self.n_rot)
+ gguf_writer.add_feed_forward_length(self.get_n_ff())
+
+def tensor_name(key, bid=None, suffix=".weight"):
+ return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + suffix
+
+class Layer:
+ def __init__(self, params, lora_params, bid):
+ self.bid = bid
+ self.att_norm_a = Tensor('f', [lora_params.n_rank_attention_norm, params.n_embd])
+ self.att_norm_b = Tensor('f', [lora_params.n_rank_attention_norm, 1])
+ self.wq_a = Tensor('f', [lora_params.n_rank_wq, params.n_embd])
+ self.wq_b = Tensor('f', [lora_params.n_rank_wq, params.n_embd])
+ self.wk_a = Tensor('f', [lora_params.n_rank_wk, params.n_embd])
+ self.wk_b = Tensor('f', [lora_params.n_rank_wk, params.n_embd])
+ self.wv_a = Tensor('f', [lora_params.n_rank_wv, params.n_embd])
+ self.wv_b = Tensor('f', [lora_params.n_rank_wv, params.n_embd])
+ self.wo_a = Tensor('f', [lora_params.n_rank_wo, params.n_embd])
+ self.wo_b = Tensor('f', [lora_params.n_rank_wo, params.n_embd])
+ self.ffn_norm_a = Tensor('f', [lora_params.n_rank_ffn_norm, params.n_embd])
+ self.ffn_norm_b = Tensor('f', [lora_params.n_rank_ffn_norm, 1])
+ self.w1_a = Tensor('f', [lora_params.n_rank_w1, params.n_embd])
+ self.w1_b = Tensor('f', [lora_params.n_rank_w1, params.get_n_ff()])
+ self.w2_a = Tensor('f', [lora_params.n_rank_w2, params.get_n_ff()])
+ self.w2_b = Tensor('f', [lora_params.n_rank_w2, params.n_embd])
+ self.w3_a = Tensor('f', [lora_params.n_rank_w3, params.n_embd])
+ self.w3_b = Tensor('f', [lora_params.n_rank_w3, params.get_n_ff()])
+
+ def load(self, data, offset):
+ offset = self.att_norm_a.load(data, offset)
+ offset = self.att_norm_b.load(data, offset)
+ offset = self.wq_a.load(data, offset)
+ offset = self.wq_b.load(data, offset)
+ offset = self.wk_a.load(data, offset)
+ offset = self.wk_b.load(data, offset)
+ offset = self.wv_a.load(data, offset)
+ offset = self.wv_b.load(data, offset)
+ offset = self.wo_a.load(data, offset)
+ offset = self.wo_b.load(data, offset)
+ offset = self.ffn_norm_a.load(data, offset)
+ offset = self.ffn_norm_b.load(data, offset)
+ offset = self.w1_a.load(data, offset)
+ offset = self.w1_b.load(data, offset)
+ offset = self.w2_a.load(data, offset)
+ offset = self.w2_b.load(data, offset)
+ offset = self.w3_a.load(data, offset)
+ offset = self.w3_b.load(data, offset)
+ return offset
+
+ def save_gguf(self, gguf_writer):
+ self.att_norm_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, self.bid, ".weight.lora_a"))
+ self.att_norm_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, self.bid, ".weight.lora_b"))
+ self.wq_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_Q, self.bid, ".weight.lora_a"))
+ self.wq_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_Q, self.bid, ".weight.lora_b"))
+ self.wk_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_K, self.bid, ".weight.lora_a"))
+ self.wk_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_K, self.bid, ".weight.lora_b"))
+ self.wv_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_V, self.bid, ".weight.lora_a"))
+ self.wv_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_V, self.bid, ".weight.lora_b"))
+ self.wo_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, self.bid, ".weight.lora_a"))
+ self.wo_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, self.bid, ".weight.lora_b"))
+ self.ffn_norm_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_NORM, self.bid, ".weight.lora_a"))
+ self.ffn_norm_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_NORM, self.bid, ".weight.lora_b"))
+ self.w1_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_GATE, self.bid, ".weight.lora_a"))
+ self.w1_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_GATE, self.bid, ".weight.lora_b"))
+ self.w2_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, self.bid, ".weight.lora_a"))
+ self.w2_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, self.bid, ".weight.lora_b"))
+ self.w3_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_UP, self.bid, ".weight.lora_a"))
+ self.w3_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.FFN_UP, self.bid, ".weight.lora_b"))
+
+class LoraModel:
+ def __init__(self, n_ff = None):
+ self.params = ModelParams(n_ff = n_ff)
+ self.lora_params = LoraParams()
+ self.layers = []
+
+ def load(self, data, offset):
+ offset = self.params.load(data, offset)
+ offset = self.lora_params.load(data, offset)
+
+ self.tok_embd_a = Tensor('f', [self.lora_params.n_rank_tok_embeddings, self.params.n_embd])
+ self.tok_embd_b = Tensor('f', [self.lora_params.n_rank_tok_embeddings, self.params.n_vocab])
+ self.norm_a = Tensor('f', [self.lora_params.n_rank_norm, self.params.n_embd])
+ self.norm_b = Tensor('f', [self.lora_params.n_rank_norm, 1])
+ self.output_a = Tensor('f', [self.lora_params.n_rank_output, self.params.n_embd])
+ self.output_b = Tensor('f', [self.lora_params.n_rank_output, self.params.n_vocab])
+
+ offset = self.tok_embd_a.load(data, offset)
+ offset = self.tok_embd_b.load(data, offset)
+ offset = self.norm_a.load(data, offset)
+ offset = self.norm_b.load(data, offset)
+ offset = self.output_a.load(data, offset)
+ offset = self.output_b.load(data, offset)
+
+ self.layers.clear()
+ for bid in range(self.params.n_layer):
+ layer = Layer(self.params, self.lora_params, bid)
+ offset = layer.load(data, offset)
+ self.layers.append(layer)
+
+ return offset
+
+ def save_gguf(self, gguf_writer):
+ self.params.save_gguf(gguf_writer)
+ self.lora_params.save_gguf(gguf_writer)
+
+ self.tok_embd_a.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD, suffix=".weight.lora_a"))
+ self.tok_embd_b.save_gguf(gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD, suffix=".weight.lora_b"))
+ self.norm_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT_NORM, suffix=".weight.lora_a"))
+ self.norm_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT_NORM, suffix=".weight.lora_b"))
+ self.output_a.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight.lora_a"))
+ self.output_b.save_gguf (gguf_writer, name=tensor_name(gguf.MODEL_TENSOR.OUTPUT, suffix=".weight.lora_b"))
+
+ for layer in self.layers:
+ layer.save_gguf(gguf_writer)
+
+class LoraCheckpoint:
+ def __init__(self, n_ff = None):
+ self.model = LoraModel(n_ff = n_ff)
+ self.opt_ctx = OptimizationContext()
+
+ def load(self, data, offset):
+ magic = bytes(reversed(data[offset:offset + 4])); offset += 4
+ if magic != b'ggcl':
+ raise ValueError(f"File header magic indicates, that this is no finetune-lora checkpoint file. Expected 'ggcl', Got '{str(magic)}'")
+
+ self.version = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ if self.version != 0:
+ raise ValueError('Invalid version of checkpoint file')
+
+ self.train_its = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.train_samples = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+ self.train_tokens = struct.unpack('<I', bytes(data[offset:offset + 4]))[0]; offset += 4
+
+ offset = self.model.load(data, offset)
+ offset = self.opt_ctx.load(data, offset)
+
+ return offset
+
+ def save_gguf(self, gguf_writer):
+ gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
+ gguf_writer.add_layer_norm_rms_eps(1e-5)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION, 0)
+ gguf_writer.add_string(LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT, self.train_samples)
+ gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT, self.train_tokens)
+ self.model.save_gguf(gguf_writer)
+ self.opt_ctx.save_gguf(gguf_writer)
+
+def handle_args():
+ parser = argparse.ArgumentParser(description = 'Convert finetune checkpoints to GGUF')
+ parser.add_argument('--input', '-i', type = Path, help = 'Input finetune checkpoint filename', required=True)
+ parser.add_argument('--output', '-o', type = Path, help = 'Output GGUF filename', required=True)
+ parser.add_argument('--ff', type = int, help = "Feedforward size, if not provided compute from n_mult. Provide this if you get 'ValueError: Tensor.load: Expected number of elements does not match what is read from file'", required=False)
+ return parser.parse_args()
+
+def main():
+ cfg = handle_args()
+ print(cfg)
+ data = np.memmap(cfg.input, mode = 'r')
+ chk = LoraCheckpoint(n_ff = cfg.ff)
+ offset = 0
+ offset = chk.load(data, offset)
+ # we should have read all available data
+ assert(offset == len(data))
+
+ gguf_writer = gguf.GGUFWriter(cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], use_temp_file = False)
+ chk.save_gguf(gguf_writer)
+ print(" gguf: write header")
+ gguf_writer.write_header_to_file()
+ print(" gguf: write metadata")
+ gguf_writer.write_kv_data_to_file()
+ print(" gguf: write tensors")
+ gguf_writer.write_tensors_to_file()
+ gguf_writer.close()
+
+if __name__ == '__main__':
+ main()
--- /dev/null
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "llama.h"
+#include "common.h"
+#include "train.h"
+#include <unordered_map>
+#include <vector>
+#include <cassert>
+#include <climits>
+#include <cstring>
+#include <cstdarg>
+#include <ctime>
+#include <random>
+#include <stdexcept>
+#include <algorithm>
+#include <string>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static const size_t tensor_alignment = 32;
+
+struct my_llama_hparams {
+ uint32_t n_vocab = 32000;
+ uint32_t n_ctx = 512;
+ uint32_t n_embd = 4096;
+ uint32_t n_ff = 11008;
+ uint32_t n_head = 32;
+ uint32_t n_head_kv = 32;
+ uint32_t n_layer = 32;
+
+ // float f_norm_eps = 1e-5f; // falcon
+ float f_norm_rms_eps = 1e-5f; // llama
+
+ float rope_freq_base = 10000.0f;
+ float rope_freq_scale = 1.0f;
+
+ uint32_t n_gqa() const {
+ return n_head/n_head_kv;
+ }
+
+ uint32_t n_embd_head() const {
+ return n_embd/n_head;
+ }
+
+ uint32_t n_embd_gqa() const {
+ return n_embd/n_gqa();
+ }
+
+ bool operator!=(const my_llama_hparams& other) const {
+ return memcmp(this, &other, sizeof(other));
+ }
+};
+
+struct my_llama_layer {
+ // normalization
+ struct ggml_tensor * attention_norm;
+
+ // attention
+ struct ggml_tensor * wq;
+ struct ggml_tensor * wk;
+ struct ggml_tensor * wv;
+ struct ggml_tensor * wo;
+
+ // normalization
+ struct ggml_tensor * ffn_norm;
+
+ // ff
+ struct ggml_tensor * w1;
+ struct ggml_tensor * w2;
+ struct ggml_tensor * w3;
+};
+
+struct my_llama_model {
+ struct my_llama_hparams hparams;
+
+ struct ggml_tensor * tok_embeddings;
+
+ struct ggml_tensor * norm;
+ struct ggml_tensor * output;
+
+ std::vector<my_llama_layer> layers;
+};
+
+struct my_llama_lora_hparams {
+ uint32_t lora_r = 1;
+ uint32_t lora_alpha = 1;
+ uint32_t n_rank_attention_norm = 1;
+ uint32_t n_rank_wq = 4;
+ uint32_t n_rank_wk = 4;
+ uint32_t n_rank_wv = 4;
+ uint32_t n_rank_wo = 4;
+ uint32_t n_rank_ffn_norm = 1;
+ uint32_t n_rank_w1 = 4;
+ uint32_t n_rank_w2 = 4;
+ uint32_t n_rank_w3 = 4;
+ uint32_t n_rank_tok_embeddings = 4;
+ uint32_t n_rank_norm = 1;
+ uint32_t n_rank_output = 4;
+
+ bool operator!=(const my_llama_lora_hparams& other) const {
+ return memcmp(this, &other, sizeof(other));
+ }
+};
+
+struct my_llama_lora_layer {
+ // normalization
+ struct ggml_tensor * attention_norm_a;
+ struct ggml_tensor * attention_norm_b;
+
+ // attention
+ struct ggml_tensor * wq_a;
+ struct ggml_tensor * wq_b;
+ struct ggml_tensor * wk_a;
+ struct ggml_tensor * wk_b;
+ struct ggml_tensor * wv_a;
+ struct ggml_tensor * wv_b;
+ struct ggml_tensor * wo_a;
+ struct ggml_tensor * wo_b;
+
+ // normalization
+ struct ggml_tensor * ffn_norm_a;
+ struct ggml_tensor * ffn_norm_b;
+
+ // ff
+ struct ggml_tensor * w1_a;
+ struct ggml_tensor * w1_b;
+ struct ggml_tensor * w2_a;
+ struct ggml_tensor * w2_b;
+ struct ggml_tensor * w3_a;
+ struct ggml_tensor * w3_b;
+};
+
+struct my_llama_lora {
+ struct ggml_context * ctx = NULL;
+ std::vector<uint8_t> data;
+
+ my_llama_lora_hparams hparams;
+
+ struct ggml_tensor * tok_embeddings_a;
+ struct ggml_tensor * tok_embeddings_b;
+
+ struct ggml_tensor * norm_a;
+ struct ggml_tensor * norm_b;
+ struct ggml_tensor * output_a;
+ struct ggml_tensor * output_b;
+
+ std::vector<my_llama_lora_layer> layers;
+};
+
+// gguf constants
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE = "training.type";
+
+static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
+static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_NORM = "training.lora.rank.attn_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_Q = "training.lora.rank.attn_q";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_K = "training.lora.rank.attn_k";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_V = "training.lora.rank.attn_v";
+static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_OUT = "training.lora.rank.attn_output";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_NORM = "training.lora.rank.ffn_norm";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_GATE = "training.lora.rank.ffn_gate";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_DOWN = "training.lora.rank.ffn_down";
+static const char * LLM_KV_TRAINING_LORA_RANK_FFN_UP = "training.lora.rank.ffn_up";
+
+// gguf constants (sync with gguf.py)
+
+static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
+static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
+
+static const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
+static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
+static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
+static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV = "%s.attention.head_count_kv";
+static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
+static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
+static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
+static const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
+
+static const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
+static const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
+static const char * LLM_TENSOR_OUTPUT = "output";
+static const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
+static const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
+static const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
+static const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
+static const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
+static const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
+static const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
+static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
+static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
+
+static void print_params(struct my_llama_hparams * params) {
+ printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
+ printf("%s: n_ctx: %u\n", __func__, params->n_ctx);
+ printf("%s: n_embd: %u\n", __func__, params->n_embd);
+ printf("%s: n_ff: %u\n", __func__, params->n_ff);
+ printf("%s: n_head: %u\n", __func__, params->n_head);
+ printf("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
+ printf("%s: n_layer: %u\n", __func__, params->n_layer);
+ printf("%s: norm_rms_eps : %f\n", __func__, params->f_norm_rms_eps);
+ printf("%s: rope_freq_base : %f\n", __func__, params->rope_freq_base);
+ printf("%s: rope_freq_scale : %f\n", __func__, params->rope_freq_scale);
+}
+
+static void print_lora_params(struct my_llama_lora_hparams * params) {
+ printf("%s: n_rank_attention_norm : %u\n", __func__, params->n_rank_attention_norm);
+ printf("%s: n_rank_wq : %u\n", __func__, params->n_rank_wq);
+ printf("%s: n_rank_wk : %u\n", __func__, params->n_rank_wk);
+ printf("%s: n_rank_wv : %u\n", __func__, params->n_rank_wv);
+ printf("%s: n_rank_wo : %u\n", __func__, params->n_rank_wo);
+ printf("%s: n_rank_ffn_norm : %u\n", __func__, params->n_rank_ffn_norm);
+ printf("%s: n_rank_w1 : %u\n", __func__, params->n_rank_w1);
+ printf("%s: n_rank_w2 : %u\n", __func__, params->n_rank_w2);
+ printf("%s: n_rank_w3 : %u\n", __func__, params->n_rank_w3);
+ printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
+ printf("%s: n_rank_norm : %u\n", __func__, params->n_rank_norm);
+ printf("%s: n_rank_output : %u\n", __func__, params->n_rank_output);
+}
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+ const std::string skey(key); \
+ const int kid = gguf_find_key(ctx, skey.c_str()); \
+ if (kid >= 0) { \
+ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+ if (ktype != (type)) { \
+ die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
+ } \
+ (dst) = func(ctx, kid); \
+ } else if (req) { \
+ die_fmt("key not found in model: %s", skey.c_str()); \
+ } \
+}
+
+static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_hparams * hparams, const char * expected_arch) {
+ std::string arch;
+
+ GGUF_GET_KEY(ctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
+ if (expected_arch != NULL) {
+ if (arch != expected_arch) {
+ printf("%s: arch=%s expected_arch=%s\n", __func__, arch.c_str(), expected_arch);
+ }
+ GGML_ASSERT(arch == expected_arch);
+ }
+
+ std::vector<char> keybuf;
+ keybuf.resize(512);
+ auto kv = [&arch, &keybuf](const char * key) -> const char * {
+ snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
+ return keybuf.data();
+ };
+
+ GGUF_GET_KEY(ctx, hparams->n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
+ GGUF_GET_KEY(ctx, hparams->n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
+ GGUF_GET_KEY(ctx, hparams->n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+ GGUF_GET_KEY(ctx, hparams->n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+ GGUF_GET_KEY(ctx, hparams->n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
+
+ // n_head_kv is optional, default to n_head
+ hparams->n_head_kv = hparams->n_head;
+ GGUF_GET_KEY(ctx, hparams->n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+
+ float rope_freq_scale = 1.0f;
+ GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+ GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+ GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+ if (rope_freq_scale != 1.0f) {
+ hparams->rope_freq_scale = 1.0f / rope_freq_scale;
+ }
+}
+
+static void init_model(struct llama_model * input, struct my_llama_model * model, const char * fn_model, uint32_t n_ctx) {
+ auto & hparams = model->hparams;
+
+ std::vector<char> tn_buf;
+ tn_buf.resize(GGML_MAX_NAME);
+ auto tn = [&tn_buf](const char * key) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
+ return tn_buf.data();
+ };
+ auto tni = [&tn_buf](const char * key, int bid) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+ std::string s = tn_buf.data();
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
+ return tn_buf.data();
+ };
+
+
+ // get parameters directly from gguf file
+ {
+ struct gguf_init_params params = {
+ /*.no_alloc = */ false,
+ /*.ctx = */ NULL,
+ };
+ struct gguf_context * mctx = gguf_init_from_file(fn_model, params);
+
+ load_model_hparams_gguf(mctx, &hparams, "llama");
+
+ gguf_free(mctx);
+ }
+ hparams.n_vocab = llama_model_n_vocab(input);
+ hparams.n_ctx = n_ctx;
+
+ // get tensors from llama_model (possibly mmapped)
+ model->tok_embeddings = llama_get_model_tensor(input, tn(LLM_TENSOR_TOKEN_EMBD));
+ model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
+ model->output = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
+
+ assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
+ assert_shape_1d(model->norm, hparams.n_embd);
+ assert_shape_2d(model->output, hparams.n_embd, hparams.n_vocab);
+
+ model->layers.resize(hparams.n_layer);
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+ auto & layer = model->layers[i];
+
+ layer.attention_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_NORM, i));
+ layer.wq = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_Q, i));
+ layer.wk = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_K, i));
+ layer.wv = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_V, i));
+ layer.wo = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_OUT, i));
+ layer.ffn_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_NORM, i));
+ layer.w1 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
+ layer.w2 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
+ layer.w3 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
+
+ assert_shape_1d(layer.attention_norm, hparams.n_embd);
+ assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
+ assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
+ assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
+ assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
+ assert_shape_1d(layer.ffn_norm, hparams.n_embd);
+ assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);
+ assert_shape_2d(layer.w2, hparams.n_ff, hparams.n_embd);
+ assert_shape_2d(layer.w3, hparams.n_embd, hparams.n_ff);
+ }
+}
+
+static void set_param_lora(struct my_llama_lora * lora) {
+ const uint32_t n_layer = lora->layers.size();
+
+ struct ggml_context* ctx = lora->ctx;
+
+ ggml_set_param(ctx, lora->tok_embeddings_a);
+ ggml_set_param(ctx, lora->tok_embeddings_b);
+ ggml_set_param(ctx, lora->norm_a);
+ ggml_set_param(ctx, lora->norm_b);
+ ggml_set_param(ctx, lora->output_a);
+ ggml_set_param(ctx, lora->output_b);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ auto & layer = lora->layers[i];
+
+ ggml_set_param(ctx, layer.attention_norm_a);
+ ggml_set_param(ctx, layer.attention_norm_b);
+ ggml_set_param(ctx, layer.wq_a);
+ ggml_set_param(ctx, layer.wq_b);
+ ggml_set_param(ctx, layer.wk_a);
+ ggml_set_param(ctx, layer.wk_b);
+ ggml_set_param(ctx, layer.wv_a);
+ ggml_set_param(ctx, layer.wv_b);
+ ggml_set_param(ctx, layer.wo_a);
+ ggml_set_param(ctx, layer.wo_b);
+ ggml_set_param(ctx, layer.ffn_norm_a);
+ ggml_set_param(ctx, layer.ffn_norm_b);
+ ggml_set_param(ctx, layer.w1_a);
+ ggml_set_param(ctx, layer.w1_b);
+ ggml_set_param(ctx, layer.w2_a);
+ ggml_set_param(ctx, layer.w2_b);
+ ggml_set_param(ctx, layer.w3_a);
+ ggml_set_param(ctx, layer.w3_b);
+ }
+}
+
+static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
+ ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
+ ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
+ ggml_allocr_alloc(alloc, lora->norm_a);
+ ggml_allocr_alloc(alloc, lora->norm_b);
+ ggml_allocr_alloc(alloc, lora->output_a);
+ ggml_allocr_alloc(alloc, lora->output_b);
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+ ggml_allocr_alloc(alloc, layer.attention_norm_a);
+ ggml_allocr_alloc(alloc, layer.attention_norm_b);
+ ggml_allocr_alloc(alloc, layer.wq_a);
+ ggml_allocr_alloc(alloc, layer.wq_b);
+ ggml_allocr_alloc(alloc, layer.wk_a);
+ ggml_allocr_alloc(alloc, layer.wk_b);
+ ggml_allocr_alloc(alloc, layer.wv_a);
+ ggml_allocr_alloc(alloc, layer.wv_b);
+ ggml_allocr_alloc(alloc, layer.wo_a);
+ ggml_allocr_alloc(alloc, layer.wo_b);
+ ggml_allocr_alloc(alloc, layer.ffn_norm_a);
+ ggml_allocr_alloc(alloc, layer.ffn_norm_b);
+ ggml_allocr_alloc(alloc, layer.w1_a);
+ ggml_allocr_alloc(alloc, layer.w1_b);
+ ggml_allocr_alloc(alloc, layer.w2_a);
+ ggml_allocr_alloc(alloc, layer.w2_b);
+ ggml_allocr_alloc(alloc, layer.w3_a);
+ ggml_allocr_alloc(alloc, layer.w3_b);
+ }
+ ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
+ ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
+ ggml_allocr_alloc(alloc, lora->norm_a->grad);
+ ggml_allocr_alloc(alloc, lora->norm_b->grad);
+ ggml_allocr_alloc(alloc, lora->output_a->grad);
+ ggml_allocr_alloc(alloc, lora->output_b->grad);
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+ ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
+ ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
+ ggml_allocr_alloc(alloc, layer.wq_a->grad);
+ ggml_allocr_alloc(alloc, layer.wq_b->grad);
+ ggml_allocr_alloc(alloc, layer.wk_a->grad);
+ ggml_allocr_alloc(alloc, layer.wk_b->grad);
+ ggml_allocr_alloc(alloc, layer.wv_a->grad);
+ ggml_allocr_alloc(alloc, layer.wv_b->grad);
+ ggml_allocr_alloc(alloc, layer.wo_a->grad);
+ ggml_allocr_alloc(alloc, layer.wo_b->grad);
+ ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
+ ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
+ ggml_allocr_alloc(alloc, layer.w1_a->grad);
+ ggml_allocr_alloc(alloc, layer.w1_b->grad);
+ ggml_allocr_alloc(alloc, layer.w2_a->grad);
+ ggml_allocr_alloc(alloc, layer.w2_b->grad);
+ ggml_allocr_alloc(alloc, layer.w3_a->grad);
+ ggml_allocr_alloc(alloc, layer.w3_b->grad);
+ }
+}
+
+static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
+ const auto & lparams = lora->hparams;
+
+ const uint32_t n_embd = model->hparams.n_embd;
+ const uint32_t n_embd_gqa = model->hparams.n_embd_gqa();
+ const uint32_t n_layer = model->hparams.n_layer;
+ const uint32_t n_vocab = model->hparams.n_vocab;
+ const uint32_t n_ff = model->hparams.n_ff;
+
+ std::vector<char> tn_buf;
+ tn_buf.resize(GGML_MAX_NAME);
+ auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
+ return tn_buf.data();
+ };
+ auto tni = [&tn_buf](const char * key, const char * suffix, int bid) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+ std::string s = tn_buf.data();
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
+ return tn_buf.data();
+ };
+
+ // context for lora tensors without their data
+ struct ggml_init_params ctx_lora_params;
+ ctx_lora_params.mem_size = ggml_tensor_overhead()*2*(6 + n_layer*18);
+ ctx_lora_params.mem_buffer = NULL;
+ ctx_lora_params.no_alloc = true;
+
+ struct ggml_context * ctx = ggml_init(ctx_lora_params);
+ lora->ctx = ctx;
+
+ lora->tok_embeddings_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_embd);
+ lora->tok_embeddings_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_vocab);
+ lora->norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, n_embd);
+ lora->norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, 1);
+ lora->output_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_embd);
+ lora->output_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_vocab);
+
+ ggml_set_name(lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.lora_a"));
+ ggml_set_name(lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.lora_b"));
+ ggml_set_name(lora->norm_a, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_a"));
+ ggml_set_name(lora->norm_b, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_b"));
+ ggml_set_name(lora->output_a, tn(LLM_TENSOR_OUTPUT, ".weight.lora_a"));
+ ggml_set_name(lora->output_b, tn(LLM_TENSOR_OUTPUT, ".weight.lora_b"));
+
+ lora->layers.resize(n_layer);
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ auto & layer = lora->layers[i];
+
+ layer.attention_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, n_embd);
+ layer.attention_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, 1);
+
+ layer.wq_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
+ layer.wq_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
+ layer.wk_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd);
+ layer.wk_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd_gqa);
+ layer.wv_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd);
+ layer.wv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd_gqa);
+ layer.wo_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
+ layer.wo_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
+
+ layer.ffn_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, n_embd);
+ layer.ffn_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, 1);
+
+ layer.w1_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_embd);
+ layer.w1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_ff);
+ layer.w2_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_ff);
+ layer.w2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_embd);
+ layer.w3_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_embd);
+ layer.w3_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_ff);
+
+ ggml_set_name(layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_a", i));
+ ggml_set_name(layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_b", i));
+ ggml_set_name(layer.wq_a, tni(LLM_TENSOR_ATTN_Q, ".weight.lora_a", i));
+ ggml_set_name(layer.wq_b, tni(LLM_TENSOR_ATTN_Q, ".weight.lora_b", i));
+ ggml_set_name(layer.wk_a, tni(LLM_TENSOR_ATTN_K, ".weight.lora_a", i));
+ ggml_set_name(layer.wk_b, tni(LLM_TENSOR_ATTN_K, ".weight.lora_b", i));
+ ggml_set_name(layer.wv_a, tni(LLM_TENSOR_ATTN_V, ".weight.lora_a", i));
+ ggml_set_name(layer.wv_b, tni(LLM_TENSOR_ATTN_V, ".weight.lora_b", i));
+ ggml_set_name(layer.wo_a, tni(LLM_TENSOR_ATTN_OUT, ".weight.lora_a", i));
+ ggml_set_name(layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, ".weight.lora_b", i));
+ ggml_set_name(layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_a", i));
+ ggml_set_name(layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_b", i));
+ ggml_set_name(layer.w1_a, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_a", i));
+ ggml_set_name(layer.w1_b, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_b", i));
+ ggml_set_name(layer.w2_a, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_a", i));
+ ggml_set_name(layer.w2_b, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_b", i));
+ ggml_set_name(layer.w3_a, tni(LLM_TENSOR_FFN_UP, ".weight.lora_a", i));
+ ggml_set_name(layer.w3_b, tni(LLM_TENSOR_FFN_UP, ".weight.lora_b", i));
+ }
+
+ set_param_lora(lora);
+
+ // measure data size
+ struct ggml_allocr * alloc = NULL;
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ alloc_lora(alloc, lora);
+
+ // allocate data
+ lora->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
+ ggml_allocr_free(alloc);
+ alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
+ alloc_lora(alloc, lora);
+ ggml_allocr_free(alloc);
+}
+
+static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
+ const uint32_t n_layer = lora->layers.size();
+
+ struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+
+ randomize_tensor_normal(lora->tok_embeddings_a, rnd);
+ randomize_tensor_normal(lora->tok_embeddings_b, rnd);
+ randomize_tensor_normal(lora->norm_a, rnd);
+ randomize_tensor_normal(lora->norm_b, rnd);
+ randomize_tensor_normal(lora->output_a, rnd);
+ randomize_tensor_normal(lora->output_b, rnd);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ auto & layer = lora->layers[i];
+ randomize_tensor_normal(layer.attention_norm_a, rnd);
+ randomize_tensor_normal(layer.attention_norm_b, rnd);
+
+ randomize_tensor_normal(layer.wq_a, rnd);
+ randomize_tensor_normal(layer.wq_b, rnd);
+ randomize_tensor_normal(layer.wk_a, rnd);
+ randomize_tensor_normal(layer.wk_b, rnd);
+ randomize_tensor_normal(layer.wv_a, rnd);
+ randomize_tensor_normal(layer.wv_b, rnd);
+ randomize_tensor_normal(layer.wo_a, rnd);
+ randomize_tensor_normal(layer.wo_b, rnd);
+
+ randomize_tensor_normal(layer.ffn_norm_a, rnd);
+ randomize_tensor_normal(layer.ffn_norm_b, rnd);
+
+ randomize_tensor_normal(layer.w1_a, rnd);
+ randomize_tensor_normal(layer.w1_b, rnd);
+ randomize_tensor_normal(layer.w2_a, rnd);
+ randomize_tensor_normal(layer.w2_b, rnd);
+ randomize_tensor_normal(layer.w3_a, rnd);
+ randomize_tensor_normal(layer.w3_b, rnd);
+ }
+
+ free_random_normal_distribution(rnd);
+}
+
+static struct ggml_tensor * llama_build_lora_finetune_graphs(
+ struct my_llama_model * model,
+ struct my_llama_lora * lora,
+ struct ggml_allocr * alloc,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * logits,
+ struct ggml_tensor * tokens_input,
+ struct ggml_tensor * targets,
+ const int n_tokens,
+ const int n_batch,
+ const bool enable_flash_attn,
+ const bool enable_checkpointing) {
+
+ ggml_set_scratch(ctx, { 0, 0, nullptr, });
+ const int n_past = 0;
+ const int N = n_tokens;
+ const auto & hparams = model->hparams;
+ const int n_ctx = hparams.n_ctx;
+ const int n_vocab = hparams.n_vocab;
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_head = hparams.n_head;
+ const int n_head_kv = hparams.n_head_kv;
+ const int n_ff = hparams.n_ff;
+ const int n_rot = hparams.n_embd_head();
+ const int n_embd_head = hparams.n_embd_head();
+ const int n_embd_gqa = hparams.n_embd_gqa();
+ const float rms_norm_eps = hparams.f_norm_rms_eps;
+ const float rope_freq_base = hparams.rope_freq_base;
+ const float rope_freq_scale = hparams.rope_freq_scale;
+
+ GGML_ASSERT((size_t) n_layer == lora->layers.size());
+
+ auto set_name = [](struct ggml_tensor * t, const char * n) {
+ ggml_set_name(t, n);
+ if (t->grad) {
+ ggml_format_name(t->grad, "%s->grad", n);
+ }
+ };
+
+ // KQ_pos - contains the positions
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+ {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+
+ // rope has so much parameters that we make a custom function for it
+ auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+ (struct ggml_tensor * t) -> struct ggml_tensor * {
+ // not capturing these, to silcence warnings
+ const int rope_mode = 0;
+
+ return ggml_rope_custom(ctx,
+ t, KQ_pos, n_rot, rope_mode, n_ctx,
+ rope_freq_base, rope_freq_scale);
+ };
+
+ set_name(tokens_input, "tokens_input");
+ set_name(targets, "targets");
+
+ GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+
+ auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+ if (ggml_is_quantized(a->type)) {
+ return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
+ } else if (a->type == GGML_TYPE_F32) {
+ return ggml_add(ctx, a, b);
+ } else {
+ die_fmt("%s: Finetuning on tensors with type '%s' is not yet supported.\n",
+ __func__, ggml_type_name(a->type));
+ }
+ };
+
+ struct ggml_tensor * tok_embeddings = add_to_f32(ctx, model->tok_embeddings, ggml_mul_mat(ctx, lora->tok_embeddings_a, lora->tok_embeddings_b));
+ struct ggml_tensor * norm = add_to_f32(ctx, model->norm, ggml_mul_mat(ctx, lora->norm_a, lora->norm_b));
+ struct ggml_tensor * output = add_to_f32(ctx, model->output, ggml_mul_mat(ctx, lora->output_a, lora->output_b));
+
+ struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
+ struct ggml_tensor * t01 = ggml_get_rows(ctx, tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
+
+ struct ggml_tensor * cur = t01;
+
+ std::vector<struct ggml_tensor *> checkpoints;
+ if (enable_checkpointing) {
+ checkpoints.push_back(tokens_input);
+ checkpoints.push_back(targets);
+ checkpoints.push_back(t00);
+ checkpoints.push_back(t01);
+ }
+
+ struct ggml_tensor * kv_scale = NULL;
+ if (!enable_flash_attn) {
+ kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct my_llama_layer & layer = model->layers[il];
+ struct my_llama_lora_layer & llayer = lora->layers[il];
+
+ struct ggml_tensor * attention_norm = add_to_f32(ctx, layer.attention_norm, ggml_mul_mat(ctx, llayer.attention_norm_a, llayer.attention_norm_b));
+ struct ggml_tensor * ffn_norm = add_to_f32(ctx, layer.ffn_norm, ggml_mul_mat(ctx, llayer.ffn_norm_a, llayer.ffn_norm_b));
+ struct ggml_tensor * wq = add_to_f32(ctx, layer.wq, ggml_mul_mat(ctx, llayer.wq_a, llayer.wq_b));
+ struct ggml_tensor * wk = add_to_f32(ctx, layer.wk, ggml_mul_mat(ctx, llayer.wk_a, llayer.wk_b));
+ struct ggml_tensor * wv = add_to_f32(ctx, layer.wv, ggml_mul_mat(ctx, llayer.wv_a, llayer.wv_b));
+ struct ggml_tensor * wo = add_to_f32(ctx, layer.wo, ggml_mul_mat(ctx, llayer.wo_a, llayer.wo_b));
+ struct ggml_tensor * w1 = add_to_f32(ctx, layer.w1, ggml_mul_mat(ctx, llayer.w1_a, llayer.w1_b));
+ struct ggml_tensor * w2 = add_to_f32(ctx, layer.w2, ggml_mul_mat(ctx, llayer.w2_a, llayer.w2_b));
+ struct ggml_tensor * w3 = add_to_f32(ctx, layer.w3, ggml_mul_mat(ctx, llayer.w3_a, llayer.w3_b));
+
+ struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
+ struct ggml_tensor * t03 = ggml_repeat (ctx, attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
+ struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
+ struct ggml_tensor * t05 = ggml_mul_mat (ctx, wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
+ struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd_head, n_head, N, n_batch);
+ struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd_head, n_head, N, n_batch);
+ struct ggml_tensor * t08 = ggml_mul_mat (ctx, wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd_gqa, N*n_batch);
+ struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd_head, n_head_kv, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd_head, n_head_kv, N, n_batch);
+ struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd_head, n_head_kv, N, n_batch);
+
+ struct ggml_tensor * t11;
+ if (ggml_is_quantized(wv->type)) {
+ struct ggml_tensor * t11_1 = ggml_mul_mat (ctx, wv, t04); set_name(t11_1, "t11_1"); assert_shape_2d(t11_1, n_embd_gqa, N*n_batch);
+ struct ggml_tensor * t11_2 = ggml_transpose(ctx, t11_1); set_name(t11_2, "t11_2"); assert_shape_2d(t11_2, N*n_batch, n_embd_gqa);
+ t11 = ggml_cont (ctx, t11_2); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd_gqa);
+ } else {
+ t11 = ggml_mul_mat (ctx, t04, wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd_gqa);
+ }
+
+ struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd_head, n_head_kv); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd_head, n_head_kv);
+ struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd_head, N, n_head, n_batch);
+ struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd_head, N, n_head_kv, n_batch);
+ struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
+ struct ggml_tensor * t16;
+ if (enable_flash_attn) {
+ t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
+ } else {
+ struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
+ struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
+ t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
+ }
+ struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd_head, n_head, N, n_batch);
+ struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd_head, n_head, N, n_batch);
+ struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
+ struct ggml_tensor * t20 = ggml_mul_mat (ctx, wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
+ struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
+ struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, rms_norm_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
+ struct ggml_tensor * t23 = ggml_repeat (ctx, ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
+ struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
+ struct ggml_tensor * t25 = ggml_mul_mat (ctx, w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
+ struct ggml_tensor * t26 = ggml_mul_mat (ctx, w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
+ struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
+ struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
+ struct ggml_tensor * t29 = ggml_mul_mat (ctx, w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
+ struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
+ cur = t30;
+ if (enable_checkpointing) {
+ checkpoints.push_back(cur);
+ }
+ }
+ struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
+ struct ggml_tensor * t32 = ggml_repeat (ctx, norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
+ struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
+ struct ggml_tensor * t34 = ggml_mul_mat (ctx, output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
+ struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
+ struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
+
+ if (enable_checkpointing) {
+ checkpoints.push_back(t31);
+ checkpoints.push_back(t32);
+ checkpoints.push_back(t33);
+ checkpoints.push_back(t34);
+ checkpoints.push_back(t35);
+ checkpoints.push_back(t36);
+ }
+
+ ggml_build_forward_expand(gf, t36);
+
+ if (enable_checkpointing) {
+ ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
+ } else {
+ *gb = *gf;
+ ggml_build_backward_expand(ctx, gf, gb, true);
+ }
+
+ GGML_ASSERT(alloc != NULL);
+
+ // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
+ int n_leafs_before = gb->n_leafs;
+ int n_nodes_before = gb->n_nodes;
+ struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+ // output tensors
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+ // input gradient
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+ GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
+ ggml_allocr_alloc(alloc, t36->grad);
+
+ // make sure base model tensors data cannot be used in viewable operations
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
+ for (int il = 0; il < n_layer; ++il) {
+ struct my_llama_layer & layer = model->layers[il];
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
+ }
+
+ // allocating checkpoints in one block to reduce memory fragmentation
+ // note: they will be freed in reverse order
+ for (unsigned int i = 0; i < checkpoints.size(); ++i) {
+ if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
+ ggml_allocr_alloc(alloc, checkpoints[i]);
+ }
+ }
+
+ ggml_allocr_alloc_graph(alloc, gb);
+
+ // remove the additional nodes and leafs
+ for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
+ gb->leafs[i] = NULL;
+ }
+ for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
+ gb->nodes[i] = NULL;
+ }
+ gb->n_leafs = n_leafs_before;
+ gb->n_nodes = n_nodes_before;
+
+ *logits = t35;
+ return t36;
+}
+
+static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) {
+ // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
+
+ std::string arch;
+
+ std::vector<char> keybuf;
+ keybuf.resize(512);
+
+ GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
+ GGML_ASSERT(arch == "llama");
+
+ uint32_t ftype_u;
+ GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
+ GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
+
+ struct my_llama_hparams hparams;
+ load_model_hparams_gguf(fctx, &hparams, arch.c_str());
+
+ // parameters that define tensor shapes must match
+ GGML_ASSERT(hparams.n_embd == model->hparams.n_embd);
+ GGML_ASSERT(hparams.n_ff == model->hparams.n_ff);
+ GGML_ASSERT(hparams.n_head == model->hparams.n_head);
+ GGML_ASSERT(hparams.n_head_kv == model->hparams.n_head_kv);
+ GGML_ASSERT(hparams.n_layer == model->hparams.n_layer);
+
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_tok_embeddings, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_output, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_attention_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wq, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_Q);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wk, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_K);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wv, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_V);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wo, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_NORM);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_w1, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_GATE);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_w2, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN);
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_w3, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_UP);
+
+ init_lora(model, lora);
+
+ copy_tensor_by_name(lora->tok_embeddings_a, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_a));
+ copy_tensor_by_name(lora->tok_embeddings_b, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_b));
+ copy_tensor_by_name(lora->norm_a, f_ggml_ctx, ggml_get_name(lora->norm_a));
+ copy_tensor_by_name(lora->norm_b, f_ggml_ctx, ggml_get_name(lora->norm_b));
+ copy_tensor_by_name(lora->output_a, f_ggml_ctx, ggml_get_name(lora->output_a));
+ copy_tensor_by_name(lora->output_b, f_ggml_ctx, ggml_get_name(lora->output_b));
+
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+ copy_tensor_by_name(layer.attention_norm_a, f_ggml_ctx, ggml_get_name(layer.attention_norm_a));
+ copy_tensor_by_name(layer.attention_norm_b, f_ggml_ctx, ggml_get_name(layer.attention_norm_b));
+ copy_tensor_by_name(layer.wq_a, f_ggml_ctx, ggml_get_name(layer.wq_a));
+ copy_tensor_by_name(layer.wq_b, f_ggml_ctx, ggml_get_name(layer.wq_b));
+ copy_tensor_by_name(layer.wk_a, f_ggml_ctx, ggml_get_name(layer.wk_a));
+ copy_tensor_by_name(layer.wk_b, f_ggml_ctx, ggml_get_name(layer.wk_b));
+ copy_tensor_by_name(layer.wv_a, f_ggml_ctx, ggml_get_name(layer.wv_a));
+ copy_tensor_by_name(layer.wv_b, f_ggml_ctx, ggml_get_name(layer.wv_b));
+ copy_tensor_by_name(layer.wo_a, f_ggml_ctx, ggml_get_name(layer.wo_a));
+ copy_tensor_by_name(layer.wo_b, f_ggml_ctx, ggml_get_name(layer.wo_b));
+ copy_tensor_by_name(layer.ffn_norm_a, f_ggml_ctx, ggml_get_name(layer.ffn_norm_a));
+ copy_tensor_by_name(layer.ffn_norm_b, f_ggml_ctx, ggml_get_name(layer.ffn_norm_b));
+ copy_tensor_by_name(layer.w1_a, f_ggml_ctx, ggml_get_name(layer.w1_a));
+ copy_tensor_by_name(layer.w1_b, f_ggml_ctx, ggml_get_name(layer.w1_b));
+ copy_tensor_by_name(layer.w2_a, f_ggml_ctx, ggml_get_name(layer.w2_a));
+ copy_tensor_by_name(layer.w2_b, f_ggml_ctx, ggml_get_name(layer.w2_b));
+ copy_tensor_by_name(layer.w3_a, f_ggml_ctx, ggml_get_name(layer.w3_a));
+ copy_tensor_by_name(layer.w3_b, f_ggml_ctx, ggml_get_name(layer.w3_b));
+ }
+}
+
+static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora) {
+ const char * arch = "llama";
+ enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
+
+ std::vector<char> keybuf;
+ keybuf.resize(512);
+ auto kv = [arch, &keybuf](const char * key) -> const char * {
+ snprintf(keybuf.data(), keybuf.size(), key, arch);
+ return keybuf.data();
+ };
+
+ gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
+ gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
+
+ gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx);
+ gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
+ gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
+ gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
+ gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head_kv);
+ gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
+ gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_embd_head());
+ gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps);
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base);
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), model->hparams.rope_freq_scale);
+
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD, lora->hparams.n_rank_tok_embeddings);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM, lora->hparams.n_rank_norm);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT, lora->hparams.n_rank_output);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM, lora->hparams.n_rank_attention_norm);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_Q, lora->hparams.n_rank_wq);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_K, lora->hparams.n_rank_wk);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_V, lora->hparams.n_rank_wv);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT, lora->hparams.n_rank_wo);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_NORM, lora->hparams.n_rank_ffn_norm);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_GATE, lora->hparams.n_rank_w1);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN, lora->hparams.n_rank_w2);
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_UP, lora->hparams.n_rank_w3);
+
+ gguf_add_tensor(fctx, lora->tok_embeddings_a);
+ gguf_add_tensor(fctx, lora->tok_embeddings_b);
+ gguf_add_tensor(fctx, lora->norm_a);
+ gguf_add_tensor(fctx, lora->norm_b);
+ gguf_add_tensor(fctx, lora->output_a);
+ gguf_add_tensor(fctx, lora->output_b);
+
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+
+ gguf_add_tensor(fctx, layer.attention_norm_a);
+ gguf_add_tensor(fctx, layer.attention_norm_b);
+ gguf_add_tensor(fctx, layer.wq_a);
+ gguf_add_tensor(fctx, layer.wq_b);
+ gguf_add_tensor(fctx, layer.wk_a);
+ gguf_add_tensor(fctx, layer.wk_b);
+ gguf_add_tensor(fctx, layer.wv_a);
+ gguf_add_tensor(fctx, layer.wv_b);
+ gguf_add_tensor(fctx, layer.wo_a);
+ gguf_add_tensor(fctx, layer.wo_b);
+ gguf_add_tensor(fctx, layer.ffn_norm_a);
+ gguf_add_tensor(fctx, layer.ffn_norm_b);
+ gguf_add_tensor(fctx, layer.w1_a);
+ gguf_add_tensor(fctx, layer.w1_b);
+ gguf_add_tensor(fctx, layer.w2_a);
+ gguf_add_tensor(fctx, layer.w2_b);
+ gguf_add_tensor(fctx, layer.w3_a);
+ gguf_add_tensor(fctx, layer.w3_b);
+ }
+}
+
+static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+ std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
+ GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+ GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+
+ load_train_state_gguf(fctx, f_ggml_ctx, train);
+ load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
+}
+
+static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+ gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+ save_llama_lora_gguf(fctx, model, lora);
+ save_train_state_gguf(fctx, train);
+}
+
+static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+ struct ggml_context * f_ggml_ctx;
+ struct gguf_init_params params;
+ params.no_alloc = false;
+ params.ctx = &f_ggml_ctx;
+ struct gguf_context * fctx = gguf_init_from_file(filename, params);
+ if (fctx == NULL) {
+ return false;
+ }
+
+ load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, train);
+
+ gguf_free(fctx);
+ return true;
+}
+
+static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+ printf("%s: saving to %s\n", __func__, filename);
+ struct gguf_context * fctx = gguf_init_empty();
+
+ save_checkpoint_lora_gguf(fctx, model, lora, train);
+
+ // write file
+ const bool only_meta = false;
+ gguf_write_to_file(fctx, filename, only_meta);
+ gguf_free(fctx);
+}
+
+struct llama_file {
+ // use FILE * so we don't have to re-open the file to mmap
+ FILE * fp;
+ size_t size;
+
+ llama_file(const char * fname, const char * mode) {
+ fp = std::fopen(fname, mode);
+ if (fp == NULL) {
+ size = 0;
+ } else {
+ seek(0, SEEK_END);
+ size = tell();
+ seek(0, SEEK_SET);
+ }
+ }
+
+ size_t tell() const {
+#ifdef _WIN32
+ __int64 ret = _ftelli64(fp);
+#else
+ long ret = std::ftell(fp);
+#endif
+ GGML_ASSERT(ret != -1); // this really shouldn't fail
+ return (size_t) ret;
+ }
+
+ void seek(size_t offset, int whence) {
+#ifdef _WIN32
+ int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+ int ret = std::fseek(fp, (long) offset, whence);
+#endif
+ GGML_ASSERT(ret == 0); // same
+ }
+
+ void read_raw(void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ std::size_t ret = std::fread(ptr, size, 1, fp);
+ if (ferror(fp)) {
+ die_fmt("read error: %s", strerror(errno));
+ }
+ if (ret != 1) {
+ die("unexpectedly reached end of file");
+ }
+ }
+
+ std::uint32_t read_u32() {
+ std::uint32_t ret;
+ read_raw(&ret, sizeof(ret));
+ return ret;
+ }
+
+ std::string read_string(std::uint32_t len) {
+ std::vector<char> chars(len);
+ read_raw(chars.data(), len);
+ return std::string(chars.data(), len);
+ }
+
+ void write_raw(const void * ptr, size_t size) {
+ if (size == 0) {
+ return;
+ }
+ errno = 0;
+ size_t ret = std::fwrite(ptr, size, 1, fp);
+ if (ret != 1) {
+ die_fmt("write error: %s", strerror(errno));
+ }
+ }
+
+ void write_u32(std::uint32_t val) {
+ write_raw(&val, sizeof(val));
+ }
+
+ ~llama_file() {
+ if (fp) {
+ std::fclose(fp);
+ }
+ }
+};
+
+static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor, const char * name) {
+ if (tensor == NULL) {
+ file->write_u32(0);
+ file->write_u32(0);
+ file->write_u32(GGML_TYPE_F32);
+ file->seek((0-file->tell()) & 31, SEEK_CUR);
+ return;
+ }
+ if (name == NULL) {
+ name = ggml_get_name(tensor);
+ }
+ uint32_t name_len = strlen(name);
+ uint32_t nd = tensor->n_dims;
+ uint32_t ne[4] = { (uint32_t)tensor->ne[0],
+ (uint32_t)tensor->ne[1],
+ (uint32_t)tensor->ne[2],
+ (uint32_t)tensor->ne[3] };
+ file->write_u32(nd);
+ file->write_u32(name_len);
+ file->write_u32(tensor->type);
+ file->write_raw(ne, sizeof(ne[0]) * nd);
+ file->write_raw(name, name_len);
+ file->seek((0-file->tell()) & 31, SEEK_CUR);
+ file->write_raw(tensor->data, ggml_nbytes(tensor));
+}
+
+static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) {
+ printf("%s: saving to %s\n", __func__, filename);
+ struct llama_file file(filename, "wb");
+ if (file.fp == NULL) {
+ return;
+ }
+
+ std::vector<char> tn_buf;
+ tn_buf.resize(GGML_MAX_NAME);
+
+ auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
+ return tn_buf.data();
+ };
+
+ auto tni = [&tn_buf](const char * key, int bid, const char * suffix) -> const char * {
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
+ std::string s = tn_buf.data();
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
+ return tn_buf.data();
+ };
+
+ uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
+ // write_magic
+ file.write_u32(LLAMA_FILE_MAGIC_LORA); // magic
+ file.write_u32(1); // version
+ // write_hparams
+ file.write_u32(lora->hparams.lora_r);
+ file.write_u32(lora->hparams.lora_alpha);
+ // write tensors
+ write_tensor(&file, lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.loraA"));
+ write_tensor(&file, lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.loraB"));
+ write_tensor(&file, lora->norm_a, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraA"));
+ write_tensor(&file, lora->norm_b, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraB"));
+ write_tensor(&file, lora->output_a, tn(LLM_TENSOR_OUTPUT, ".weight.loraA"));
+ write_tensor(&file, lora->output_b, tn(LLM_TENSOR_OUTPUT, ".weight.loraB"));
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+ write_tensor(&file, layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraA"));
+ write_tensor(&file, layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraB"));
+ write_tensor(&file, layer.wq_a, tni(LLM_TENSOR_ATTN_Q, i, ".weight.loraA"));
+ write_tensor(&file, layer.wq_b, tni(LLM_TENSOR_ATTN_Q, i, ".weight.loraB"));
+ write_tensor(&file, layer.wk_a, tni(LLM_TENSOR_ATTN_K, i, ".weight.loraA"));
+ write_tensor(&file, layer.wk_b, tni(LLM_TENSOR_ATTN_K, i, ".weight.loraB"));
+ write_tensor(&file, layer.wv_a, tni(LLM_TENSOR_ATTN_V, i, ".weight.loraA"));
+ write_tensor(&file, layer.wv_b, tni(LLM_TENSOR_ATTN_V, i, ".weight.loraB"));
+ write_tensor(&file, layer.wo_a, tni(LLM_TENSOR_ATTN_OUT, i, ".weight.loraA"));
+ write_tensor(&file, layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, i, ".weight.loraB"));
+ write_tensor(&file, layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraA"));
+ write_tensor(&file, layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraB"));
+ write_tensor(&file, layer.w1_a, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraA"));
+ write_tensor(&file, layer.w1_b, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraB"));
+ write_tensor(&file, layer.w2_a, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraA"));
+ write_tensor(&file, layer.w2_b, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraB"));
+ write_tensor(&file, layer.w3_a, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraA"));
+ write_tensor(&file, layer.w3_b, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraB"));
+ }
+}
+
+struct train_params {
+ struct train_params_common common;
+
+ const char * fn_model_base;
+ const char * fn_lora_out;
+
+ bool only_write_lora;
+
+ float f_norm_rms_eps;
+ float rope_freq_base;
+ float rope_freq_scale;
+
+ bool custom_f_norm_rms_eps;
+ bool custom_rope_freq_base;
+ bool custom_rope_freq_scale;
+
+ int32_t lora_r;
+ int32_t lora_alpha;
+ bool custom_lora_alpha;
+
+ uint32_t n_rank_attention_norm;
+ uint32_t n_rank_wq;
+ uint32_t n_rank_wk;
+ uint32_t n_rank_wv;
+ uint32_t n_rank_wo;
+ uint32_t n_rank_ffn_norm;
+ uint32_t n_rank_w1;
+ uint32_t n_rank_w2;
+ uint32_t n_rank_w3;
+ uint32_t n_rank_tok_embeddings;
+ uint32_t n_rank_norm;
+ uint32_t n_rank_output;
+
+ bool custom_n_rank_attention_norm;
+ bool custom_n_rank_wq;
+ bool custom_n_rank_wk;
+ bool custom_n_rank_wv;
+ bool custom_n_rank_wo;
+ bool custom_n_rank_ffn_norm;
+ bool custom_n_rank_w1;
+ bool custom_n_rank_w2;
+ bool custom_n_rank_w3;
+ bool custom_n_rank_tok_embeddings;
+ bool custom_n_rank_norm;
+ bool custom_n_rank_output;
+};
+
+static struct train_params get_default_train_params() {
+ struct train_params params;
+ params.common = get_default_train_params_common();
+ params.fn_model_base = "";
+ params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
+
+ params.only_write_lora = false;
+
+ params.f_norm_rms_eps = 1e-5f;
+ params.rope_freq_base = 10000.0f;
+ params.rope_freq_scale = 1.0f;
+
+ params.custom_f_norm_rms_eps = false;
+ params.custom_rope_freq_base = false;
+ params.custom_rope_freq_scale = false;
+
+ params.lora_r = 4;
+ params.lora_alpha = 4;
+ params.custom_lora_alpha = false;
+
+ params.n_rank_attention_norm = 1;
+ params.n_rank_wq = 4;
+ params.n_rank_wk = 4;
+ params.n_rank_wv = 4;
+ params.n_rank_wo = 4;
+ params.n_rank_ffn_norm = 1;
+ params.n_rank_w1 = 4;
+ params.n_rank_w2 = 4;
+ params.n_rank_w3 = 4;
+ params.n_rank_tok_embeddings = 4;
+ params.n_rank_norm = 1;
+ params.n_rank_output = 4;
+
+ params.custom_n_rank_attention_norm = false;
+ params.custom_n_rank_wq = false;
+ params.custom_n_rank_wk = false;
+ params.custom_n_rank_wv = false;
+ params.custom_n_rank_wo = false;
+ params.custom_n_rank_ffn_norm = false;
+ params.custom_n_rank_w1 = false;
+ params.custom_n_rank_w2 = false;
+ params.custom_n_rank_w3 = false;
+ params.custom_n_rank_tok_embeddings = false;
+ params.custom_n_rank_norm = false;
+ params.custom_n_rank_output = false;
+
+ return params;
+}
+
+static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
+ fprintf(stderr, "\n");
+ fprintf(stderr, "options:\n");
+ fprintf(stderr, " -h, --help show this help message and exit\n");
+
+ fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base);
+ fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
+ fprintf(stderr, " --only-write-lora only save llama lora, don't do any training. use this if you only want to convert a checkpoint to a lora adapter.\n");
+ fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
+ fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
+ fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
+ fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
+ fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
+ fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+ fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+ fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
+ fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
+ fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
+
+ print_common_train_usage(argc, argv, ¶ms->common);
+}
+
+static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+ bool invalid_param = false;
+ std::string arg;
+ struct train_params default_params = get_default_train_params();
+ const std::string arg_prefix = "--";
+
+ for (int i = 1; i < argc; i++) {
+ arg = argv[i];
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+ std::replace(arg.begin(), arg.end(), '_', '-');
+ }
+
+ if (consume_common_train_arg(argc, argv, &i, ¶ms->common, &invalid_param)) {
+ if (invalid_param) {
+ break;
+ } else if (params->common.print_usage) {
+ train_print_usage(argc, argv, &default_params);
+ exit(0);
+ }
+ } else if (arg == "--model-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->fn_model_base = argv[i];
+ } else if (arg == "--lora-out") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->fn_lora_out = argv[i];
+ } else if (arg == "--only-write-lora") {
+ params->only_write_lora = true;
+ } else if (arg == "--norm-rms-eps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->f_norm_rms_eps = std::stof(argv[i]);
+ params->custom_f_norm_rms_eps = true;
+ } else if (arg == "--rope-freq-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->rope_freq_base = std::stof(argv[i]);
+ params->custom_rope_freq_base = true;
+ } else if (arg == "--rope-freq-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->rope_freq_scale = std::stof(argv[i]);
+ params->custom_rope_freq_scale = true;
+ } else if (arg == "--lora-alpha") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->lora_alpha = std::stoi(argv[i]);
+ params->custom_lora_alpha = true;
+ } else if (arg == "--lora-r") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->lora_r = std::stoi(argv[i]);
+ } else if (arg == "--rank-att-norm") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_attention_norm = std::stoi(argv[i]);
+ params->custom_n_rank_attention_norm = true;
+ } else if (arg == "--rank-ffn-norm") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_ffn_norm = std::stoi(argv[i]);
+ params->custom_n_rank_ffn_norm = true;
+ } else if (arg == "--rank-out-norm") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_norm = std::stoi(argv[i]);
+ params->custom_n_rank_norm = true;
+ } else if (arg == "--rank-tok-embd") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_tok_embeddings = std::stoi(argv[i]);
+ params->custom_n_rank_tok_embeddings = true;
+ } else if (arg == "--rank-out") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_output = std::stoi(argv[i]);
+ params->custom_n_rank_output = true;
+ } else if (arg == "--rank-wq") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_wq = std::stoi(argv[i]);
+ params->custom_n_rank_wq = true;
+ } else if (arg == "--rank-wk") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_wk = std::stoi(argv[i]);
+ params->custom_n_rank_wk = true;
+ } else if (arg == "--rank-wv") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_wv = std::stoi(argv[i]);
+ params->custom_n_rank_wv = true;
+ } else if (arg == "--rank-wo") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_wo = std::stoi(argv[i]);
+ params->custom_n_rank_wo = true;
+ } else if (arg == "--rank-w1") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_w1 = std::stoi(argv[i]);
+ params->custom_n_rank_w1 = true;
+ } else if (arg == "--rank-w2") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_w2 = std::stoi(argv[i]);
+ params->custom_n_rank_w2 = true;
+ } else if (arg == "--rank-w3") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params->n_rank_w3 = std::stoi(argv[i]);
+ params->custom_n_rank_w3 = true;
+ } else {
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+ train_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ }
+ if (invalid_param) {
+ fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+ train_print_usage(argc, argv, &default_params);
+ exit(1);
+ }
+ finish_processing_train_args(¶ms->common);
+ return true;
+}
+
+struct save_train_files_data {
+ const char * fn_checkpoint_out;
+ const char * fn_lora_out;
+ const char * pattern_fn_it;
+ const char * fn_latest;
+ struct my_llama_model * model;
+ struct my_llama_lora * lora;
+};
+
+static void save_train_files(void * vdata, struct train_state * train) {
+ struct save_train_files_data * data = (struct save_train_files_data *) vdata;
+
+ int64_t iter = train->opt->iter;
+
+ if (strlen(data->fn_checkpoint_out) > 0) {
+ save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train);
+ save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->model, data->lora, train);
+ }
+ if (strlen(data->fn_lora_out) > 0) {
+ save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora);
+ save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->lora);
+ }
+}
+
+static int64_t get_parameter_count(struct my_llama_lora* lora) {
+ int64_t nx = 0;
+ nx += ggml_nelements(lora->tok_embeddings_a);
+ nx += ggml_nelements(lora->tok_embeddings_b);
+ nx += ggml_nelements(lora->norm_a);
+ nx += ggml_nelements(lora->norm_b);
+ nx += ggml_nelements(lora->output_a);
+ nx += ggml_nelements(lora->output_b);
+
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
+ auto & layer = lora->layers[i];
+ nx += ggml_nelements(layer.attention_norm_a);
+ nx += ggml_nelements(layer.attention_norm_b);
+ nx += ggml_nelements(layer.wq_a);
+ nx += ggml_nelements(layer.wq_b);
+ nx += ggml_nelements(layer.wk_a);
+ nx += ggml_nelements(layer.wk_b);
+ nx += ggml_nelements(layer.wv_a);
+ nx += ggml_nelements(layer.wv_b);
+ nx += ggml_nelements(layer.wo_a);
+ nx += ggml_nelements(layer.wo_b);
+ nx += ggml_nelements(layer.ffn_norm_a);
+ nx += ggml_nelements(layer.ffn_norm_b);
+ nx += ggml_nelements(layer.w1_a);
+ nx += ggml_nelements(layer.w1_b);
+ nx += ggml_nelements(layer.w2_a);
+ nx += ggml_nelements(layer.w2_b);
+ nx += ggml_nelements(layer.w3_a);
+ nx += ggml_nelements(layer.w3_b);
+ }
+ return nx;
+}
+
+int main(int argc, char ** argv) {
+ struct train_params params = get_default_train_params();
+
+ if (!train_params_parse(argc, argv, ¶ms)) {
+ return 1;
+ }
+
+ if (params.common.seed == LLAMA_DEFAULT_SEED) {
+ params.common.seed = time(NULL);
+ }
+ printf("%s: seed: %u\n", __func__, params.common.seed);
+ srand(params.common.seed);
+
+ struct llama_context_params llama_params = llama_context_default_params();
+ llama_params.vocab_only = false;
+
+ printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
+ struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params);
+ struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+
+ struct my_llama_model model;
+ init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);
+
+ struct my_llama_lora lora;
+
+ struct train_state * train = init_train_state();
+ struct ggml_opt_context * opt = train->opt;
+
+ // set params from command line
+ if (params.custom_f_norm_rms_eps) {
+ model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+ }
+ if (params.custom_rope_freq_base) {
+ model.hparams.rope_freq_base = params.rope_freq_base;
+ }
+ if (params.custom_rope_freq_scale) {
+ model.hparams.rope_freq_scale = params.rope_freq_scale;
+ }
+ lora.hparams.lora_r = params.lora_r;
+ lora.hparams.lora_alpha = params.custom_lora_alpha ? params.lora_alpha : params.lora_r;
+ uint32_t n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
+ uint32_t n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
+ uint32_t n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
+ uint32_t n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
+ uint32_t n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
+ uint32_t n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
+ uint32_t n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
+ uint32_t n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
+ uint32_t n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
+ uint32_t n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
+ uint32_t n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
+ uint32_t n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
+ lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
+ lora.hparams.n_rank_wq = n_rank_wq;
+ lora.hparams.n_rank_wk = n_rank_wk;
+ lora.hparams.n_rank_wv = n_rank_wv;
+ lora.hparams.n_rank_wo = n_rank_wo;
+ lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
+ lora.hparams.n_rank_w1 = n_rank_w1;
+ lora.hparams.n_rank_w2 = n_rank_w2;
+ lora.hparams.n_rank_w3 = n_rank_w3;
+ lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
+ lora.hparams.n_rank_norm = n_rank_norm;
+ lora.hparams.n_rank_output = n_rank_output;
+
+ // set opt params from command line
+ opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+ opt->params.print_forward_graph = false;
+ opt->params.print_backward_graph = false;
+ opt->params.n_threads = params.common.n_threads;
+ opt->params.past = params.common.opt_past;
+ opt->params.delta = params.common.opt_delta;
+ opt->params.max_no_improvement = params.common.opt_max_no_improvement;
+ opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+ opt->params.adam.n_iter = params.common.adam_n_iter;
+ opt->params.adam.sched = 1.0f;
+ opt->params.adam.alpha = params.common.adam_alpha;
+ opt->params.adam.decay = params.common.adam_decay;
+ opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+ opt->params.adam.beta1 = params.common.adam_beta1;
+ opt->params.adam.beta2 = params.common.adam_beta2;
+ opt->params.adam.gclip = params.common.adam_gclip;
+ opt->params.adam.eps_f = params.common.adam_eps_f;
+
+ ggml_allocr * alloc = NULL;
+
+ printf("%s: init model\n", __func__);
+ bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
+
+ if (existed) {
+ // overwrite last n_ctx with user provided n_ctx
+ if (params.common.custom_n_ctx) {
+ model.hparams.n_ctx = params.common.n_ctx;
+ }
+
+ const bool opt_param_count_changed = (
+ (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
+ || (lora.hparams.n_rank_wq != n_rank_wq)
+ || (lora.hparams.n_rank_wk != n_rank_wk)
+ || (lora.hparams.n_rank_wv != n_rank_wv)
+ || (lora.hparams.n_rank_wo != n_rank_wo)
+ || (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
+ || (lora.hparams.n_rank_w1 != n_rank_w1)
+ || (lora.hparams.n_rank_w2 != n_rank_w2)
+ || (lora.hparams.n_rank_w3 != n_rank_w3)
+ || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
+ || (lora.hparams.n_rank_norm != n_rank_norm)
+ || (lora.hparams.n_rank_output != n_rank_output)
+ );
+
+ const bool opt_past_changed = opt->params.past != params.common.opt_past;
+
+ if (opt_param_count_changed) {
+ print_lora_params(&lora.hparams);
+ die("Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
+ // need to discard previous optimizer gradient statistics and opt_init with new shapes
+ // TODO
+ }
+ if (opt_past_changed) {
+ die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
+ // need to discard previous optimizer past function value statistics and opt_init with new shapes
+ // TODO
+ }
+ } else { // existed == false
+ init_lora(&model, &lora);
+ randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+ if (!params.only_write_lora) {
+ ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
+ }
+ }
+ opt->iter = train->train_its;
+
+ print_params(&model.hparams);
+ print_lora_params(&lora.hparams);
+ printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
+ printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
+ printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
+ printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
+ printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
+
+ if (params.only_write_lora) {
+ save_train_files_data save_data;
+ save_data.fn_checkpoint_out = "";
+ save_data.fn_lora_out = params.fn_lora_out;
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
+ save_data.fn_latest = params.common.fn_latest;
+ save_data.model = &model;
+ save_data.lora = &lora;
+
+ save_train_files(&save_data, train);
+
+ free_train_state(train);
+ ggml_free(lora.ctx);
+ llama_free(lctx);
+ llama_free_model(lmodel);
+ return 0;
+ }
+
+ printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
+ printf("%s: opt iter %d\n", __func__, opt->iter);
+
+ int n_tokens = model.hparams.n_ctx;
+ int n_vocab = model.hparams.n_vocab;
+ int n_batch = params.common.n_batch;
+
+
+ std::vector<uint8_t> mem_input_data;
+ std::vector<uint8_t> mem_compute_data;
+
+ // context for input tensors without their data
+ struct ggml_init_params ctx_input_params = {
+ ggml_tensor_overhead() * 2, // mem_size
+ NULL, // mem_buffer
+ true, // no_alloc
+ };
+ struct ggml_context * ctx_input = ggml_init(ctx_input_params);
+
+ // the input tensors
+ struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
+ struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+
+ // measure required memory for input tensors
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_allocr_alloc(alloc, tokens_input);
+ ggml_allocr_alloc(alloc, target_probs);
+ size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ ggml_allocr_free(alloc);
+ printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
+
+ // allocate input tensors
+ mem_input_data.resize(max_input_size);
+ alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
+ ggml_allocr_alloc(alloc, tokens_input);
+ ggml_allocr_alloc(alloc, target_probs);
+ ggml_allocr_free(alloc);
+
+ // context for compute tensors without their data
+ size_t estimated_compute_size_wo_data = (
+ ggml_tensor_overhead()*GGML_MAX_NODES*2
+ + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
+ params.common.use_checkpointing ? 3 : 2
+ )
+ );
+ struct ggml_init_params ctx_compute_params = {
+ estimated_compute_size_wo_data, // mem_size
+ NULL, // mem_buffer
+ true, // no_alloc
+ };
+ struct ggml_context * ctx_compute = NULL;
+
+ struct ggml_tensor * loss = NULL;
+ struct ggml_tensor * logits = NULL;
+
+ struct ggml_cgraph * gf = NULL;
+ struct ggml_cgraph * gb = NULL;
+ struct ggml_cgraph * gb_tmp = NULL;
+
+ // measure required memory for compute tensors
+ size_t best_compute_size = SIZE_MAX;
+ enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
+ // find best evaluation order
+ for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
+ ctx_compute = ggml_init(ctx_compute_params);
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ gf = ggml_new_graph(ctx_compute);
+ gf->order = (enum ggml_cgraph_eval_order) order;
+ gb = ggml_new_graph(ctx_compute);
+ gb_tmp = params.common.use_checkpointing
+ ? ggml_new_graph(ctx_compute)
+ : NULL;
+ loss = llama_build_lora_finetune_graphs(
+ &model, &lora, alloc, ctx_compute,
+ gf, gb, gb_tmp,
+ &logits, tokens_input, target_probs,
+ n_tokens, n_batch,
+ params.common.use_flash,
+ params.common.use_checkpointing
+ );
+ size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ if (max_compute_size < best_compute_size) {
+ best_compute_size = max_compute_size;
+ best_order = gf->order;
+ }
+ ggml_allocr_free(alloc);
+ ggml_free(ctx_compute);
+ }
+ size_t max_compute_size = best_compute_size;
+ printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
+ printf("%s: evaluation order = %s\n", __func__,
+ (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
+ (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
+ "invalid");
+
+ // allocate compute tensors
+ mem_compute_data.resize(max_compute_size);
+ ctx_compute = ggml_init(ctx_compute_params);
+ alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ gf = ggml_new_graph(ctx_compute);
+ gf->order = best_order;
+ gb = ggml_new_graph(ctx_compute);
+ gb_tmp = params.common.use_checkpointing
+ ? ggml_new_graph(ctx_compute)
+ : NULL;
+ loss = llama_build_lora_finetune_graphs(
+ &model, &lora, alloc, ctx_compute,
+ gf, gb, gb_tmp,
+ &logits, tokens_input, target_probs,
+ n_tokens, n_batch,
+ params.common.use_flash,
+ params.common.use_checkpointing
+ );
+ ggml_allocr_free(alloc);
+
+ // tokenize data
+ std::vector<llama_token> train_tokens;
+ std::vector<size_t> train_samples_begin;
+ std::vector<size_t> train_samples_size;
+ printf("%s: tokenize training data\n", __func__);
+ tokenize_file(lctx,
+ params.common.fn_train_data,
+ params.common.sample_start,
+ params.common.include_sample_start,
+ params.common.overlapping_samples,
+ n_tokens,
+ train_tokens,
+ train_samples_begin,
+ train_samples_size);
+ GGML_ASSERT(train_samples_begin.size() == train_samples_size.size());
+
+ printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
+
+ std::vector<size_t> token_noccurs;
+ token_noccurs.resize(model.hparams.n_vocab, 0);
+ for (unsigned int i = 0; i < train_tokens.size(); ++i) {
+ ++token_noccurs[train_tokens[i]];
+ }
+ int n_unique_tokens = 0;
+ for (unsigned int i = 0; i < token_noccurs.size(); ++i) {
+ if (token_noccurs[i] == 0) continue;
+ ++n_unique_tokens;
+ }
+ printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
+
+ size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+ const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
+ if (changed_train_data) {
+ printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
+ }
+ if (params.common.force_reshuffle) {
+ printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
+ }
+ if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+ train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
+ train->shuffle_sample_count = train_samples_size.size();
+ train->shuffle_next_sample = 0;
+ train->shuffle_samples_hash = shuffle_samples_hash;
+ }
+ std::vector<size_t> train_shuffled_samples_offs;
+ std::vector<size_t> train_shuffled_samples_begin;
+ std::vector<size_t> train_shuffled_samples_size;
+ train_shuffled_samples_offs.resize(train_samples_begin.size());
+ train_shuffled_samples_begin.resize(train_samples_begin.size());
+ train_shuffled_samples_size.resize(train_samples_size.size());
+ train->shuffle_rng_state_next = shuffle_samples(
+ train->shuffle_rng_state_current,
+ train_shuffled_samples_offs.data(),
+ train_shuffled_samples_begin.data(),
+ train_shuffled_samples_size.data(),
+ train_samples_begin.data(),
+ train_samples_size.data(),
+ train_samples_size.size());
+
+ printf("%s: begin training\n", __func__);
+
+ save_train_files_data save_data;
+ save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
+ save_data.fn_lora_out = params.fn_lora_out;
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
+ save_data.fn_latest = params.common.fn_latest;
+ save_data.model = &model;
+ save_data.lora = &lora;
+
+ struct train_opt_callback_data opt_cb_data;
+ opt_cb_data.params = ¶ms.common;
+ opt_cb_data.train = train;
+ opt_cb_data.save_cb = &save_train_files;
+ opt_cb_data.save_data = &save_data;
+ opt_cb_data.lctx = lctx;
+ opt_cb_data.last_save_iter = opt->iter;
+ opt_cb_data.tokens_data = train_tokens.data();
+ opt_cb_data.tokens_size = train_tokens.size();
+ opt_cb_data.samples_begin = train_samples_begin.data();
+ opt_cb_data.samples_size = train_samples_size.data();
+ opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
+ opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
+ opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
+ opt_cb_data.samples_count = train_samples_size.size();
+ opt_cb_data.tokens_input = tokens_input;
+ opt_cb_data.target_probs = target_probs;
+ opt_cb_data.first_iter = opt->iter;
+ opt_cb_data.first_epoch = train->train_epochs;
+ opt_cb_data.iter_at_last_epoch = -1;
+ opt_cb_data.last_time = ggml_time_ms();
+ opt_cb_data.millis_per_iter = 0.0;
+
+ // measure required memory for work buffer
+ size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
+ printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
+
+ // context for work buffer
+ struct ggml_init_params ctx_work_params = {
+ max_work_size, // mem_size
+ NULL, // mem_buffer
+ false, // no_alloc
+ };
+ struct ggml_context * ctx_work = ggml_init(ctx_work_params);
+
+ int64_t t0 = ggml_time_ms();
+
+ ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
+
+ ggml_free(ctx_work);
+ ggml_free(ctx_compute);
+ ggml_free(ctx_input);
+
+ int64_t t1 = ggml_time_ms();
+ printf("%s: total training time: ", __func__);
+ print_duration((double) (t1 - t0));
+ printf("\n");
+
+ int new_iters = opt->iter - opt_cb_data.last_save_iter;
+ if (new_iters > 0) {
+ train->train_its += new_iters;
+ train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
+
+ save_train_files(&save_data, train);
+ opt_cb_data.last_save_iter = opt->iter;
+ }
+
+ ggml_free(opt->ctx);
+ free_train_state(train);
+ ggml_free(lora.ctx);
+ llama_free(lctx);
+ llama_free_model(lmodel);
+ return 0;
+}
invalid_param = true;
break;
}
- params.lora_adapter = argv[i];
+ params.lora_adapter.push_back({argv[i], 1.0f});
+ params.use_mmap = false;
+ }
+ else if (arg == "--lora-scaled")
+ {
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ const char * lora_adapter = argv[i];
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
params.use_mmap = false;
}
else if (arg == "--lora-base")
./bin/train-text-from-scratch \
--vocab-model ../models/ggml-vocab-llama.gguf \
--ctx 64 --embd 256 --head 8 --layer 16 \
- --checkpoint-in chk-shakespeare-256x16.gguf \
- --checkpoint-out chk-shakespeare-256x16.gguf \
- --model-out ggml-shakespeare-256x16-f32.gguf \
+ --checkpoint-in chk-shakespeare-256x16-LATEST.gguf \
+ --checkpoint-out chk-shakespeare-256x16-ITERATION.gguf \
+ --model-out ggml-shakespeare-256x16-f32-ITERATION.gguf \
--train-data "shakespeare.txt" \
-t 6 -b 16 --seed 1 --adam-iter 256 \
--no-checkpointing
# predict
./bin/main -m ggml-shakespeare-256x16-f32.gguf
```
+
+Output files will be saved every N iterations (config with `--save-every N`).
+The pattern "ITERATION" in the output filenames will be replaced with the iteration number and "LATEST" for the latest output.
+
+To train GGUF models just pass them to `--checkpoint-in FN`.
LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"
LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"
-LLM_KV_TRAINING_FILE_VERSION = "training.file_version"
-LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
-LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"
-LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE = "training.type"
+LLM_KV_TRAINING_FILE_VERSION = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"
class Tensor:
def __init__(self, dtype='f', ne=None):
gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
gguf_writer.add_layer_norm_rms_eps(1e-5)
gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION, 0)
+ gguf_writer.add_string(LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL)
gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT, self.train_samples)
gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT, self.train_tokens)
#include "ggml.h"
#include "ggml-alloc.h"
#include "common.h"
+#include "train.h"
#include "llama.h"
#include <unordered_map>
#include <vector>
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-struct random_normal_distribution {
- std::mt19937 gen;
- std::normal_distribution<float> rd;
- float min;
- float max;
-};
-
-struct random_uniform_distribution {
- std::mt19937 gen;
- std::uniform_real_distribution<float> rd;
-};
-
-void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) {
- rnd->gen = std::mt19937(seed);
- rnd->rd = std::normal_distribution<float>{mean, std};
- rnd->min = min;
- rnd->max = max;
-}
-
-void init_random_uniform_distribution(struct random_uniform_distribution * rnd, int seed, float min, float max) {
- rnd->gen = std::mt19937(seed);
- rnd->rd = std::uniform_real_distribution<float>{min, max};
-}
-
-int clamp(const int v, const int min, const int max) {
- return ((v < min) ? (min) : (v > max) ? (max) : v);
-}
-
-float fclamp(const float v, const float min, const float max) {
- return ((v < min) ? (min) : (v > max) ? (max) : v);
-}
-
-float frand() {
- return (float)rand()/(float)RAND_MAX;
-}
-
-float frand_normal(struct random_normal_distribution * rnd) {
- return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
-}
-
-float frand_uniform(struct random_uniform_distribution * rnd) {
- return rnd->rd(rnd->gen);
-}
-
-struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
- float scale = 1.0f; // xavier
- switch (tensor->n_dims) {
- case 1:
- scale /= sqrtf(tensor->ne[0]);
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
- *dst = scale * frand_normal(rnd);
- }
- break;
- case 2:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *dst = scale * frand_normal(rnd);
- }
- }
- break;
- case 3:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
- *dst = scale * frand_normal(rnd);
- }
- }
- }
- break;
- case 4:
- scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
- for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
- *dst = scale * frand_normal(rnd);
- }
- }
- }
- }
- break;
- default:
- assert(false);
- };
- return tensor;
-}
-
-struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
- switch (tensor->n_dims) {
- case 1:
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
- *dst = frand_uniform(rnd);
- }
- break;
- case 2:
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *dst = frand_uniform(rnd);
- }
- }
- break;
- case 3:
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
- *dst = frand_uniform(rnd);
- }
- }
- }
- break;
- case 4:
- for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
- for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
- for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
- for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
- float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
- *dst = frand_uniform(rnd);
- }
- }
- }
- }
- break;
- default:
- assert(false);
- };
- return tensor;
-}
+static const size_t tensor_alignment = 32;
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_rot = 64;
uint32_t n_ff = 11008;
- // float f_norm_eps = 1e-5; // falcon
- float f_norm_rms_eps = 1e-5; // llama
+ // float f_norm_eps = 1e-5f; // falcon
+ float f_norm_rms_eps = 1e-5f; // llama
float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f;
struct my_llama_model {
struct ggml_context * ctx = NULL;
+ std::vector<uint8_t> data;
my_llama_hparams hparams;
struct ggml_tensor * output;
std::vector<my_llama_layer> layers;
-
- uint32_t train_its = 0;
- uint32_t train_samples = 0;
- uint32_t train_tokens = 0;
};
-// gguf constants
-const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
-const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
-const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
-const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
-const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
-const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
-const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
-const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
-const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
-const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
-const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
-const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
-const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
-const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
-const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
-const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
-const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
-const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
-
-const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
-const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
-const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
-
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
-const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
-
-const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
-const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
-const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
-const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
-
// gguf constants (sync with gguf.py)
-
-const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
-const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
-
-const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
-const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
-const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
-const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
-const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
-const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
-const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
-const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
-const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
-
-const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
-const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
-const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
-const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
-const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
-const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
-const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
-const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
-const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
-const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";
-
-const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
-const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
-const char * LLM_TENSOR_OUTPUT = "output";
-const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
-const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
-const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
-const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
-const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
-const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
-const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
-const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
-const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
-
-void print_params(struct my_llama_hparams * params) {
+static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
+static const char * LLM_KV_TRAINING_TYPE = "training.type";
+
+static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
+static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
+
+static const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
+static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
+static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
+static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
+static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
+static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
+static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
+static const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
+
+static const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model";
+static const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens";
+static const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type";
+static const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores";
+static const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges";
+static const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id";
+static const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id";
+static const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id";
+static const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id";
+static const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id";
+
+static const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
+static const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
+static const char * LLM_TENSOR_OUTPUT = "output";
+static const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
+static const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
+static const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
+static const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
+static const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
+static const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
+static const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
+static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
+static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
+
+static void print_params(struct my_llama_hparams * params) {
printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
printf("%s: n_embd: %d\n", __func__, params->n_embd);
printf("%s: n_rot: %d\n", __func__, params->n_rot);
}
-void init_model(struct my_llama_model * model) {
+static void set_param_model(struct my_llama_model * model) {
+ const auto& hparams = model->hparams;
+
+ const uint32_t n_layer = hparams.n_layer;
+
+ struct ggml_context* ctx = model->ctx;
+
+ ggml_set_param(ctx, model->tok_embeddings);
+ ggml_set_param(ctx, model->norm);
+ ggml_set_param(ctx, model->output);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ auto & layer = model->layers[i];
+
+ ggml_set_param(ctx, layer.attention_norm);
+ ggml_set_param(ctx, layer.wq);
+ ggml_set_param(ctx, layer.wk);
+ ggml_set_param(ctx, layer.wv);
+ ggml_set_param(ctx, layer.wo);
+ ggml_set_param(ctx, layer.ffn_norm);
+ ggml_set_param(ctx, layer.w1);
+ ggml_set_param(ctx, layer.w2);
+ ggml_set_param(ctx, layer.w3);
+ }
+}
+
+static void alloc_model(struct ggml_allocr * alloc, struct my_llama_model * model) {
+ ggml_allocr_alloc(alloc, model->tok_embeddings);
+ ggml_allocr_alloc(alloc, model->norm);
+ ggml_allocr_alloc(alloc, model->output);
+ for (uint32_t i = 0; i < model->layers.size(); ++i) {
+ auto & layer = model->layers[i];
+ ggml_allocr_alloc(alloc, layer.attention_norm);
+ ggml_allocr_alloc(alloc, layer.wq);
+ ggml_allocr_alloc(alloc, layer.wk);
+ ggml_allocr_alloc(alloc, layer.wv);
+ ggml_allocr_alloc(alloc, layer.wo);
+ ggml_allocr_alloc(alloc, layer.ffn_norm);
+ ggml_allocr_alloc(alloc, layer.w1);
+ ggml_allocr_alloc(alloc, layer.w2);
+ ggml_allocr_alloc(alloc, layer.w3);
+ }
+ ggml_allocr_alloc(alloc, model->tok_embeddings->grad);
+ ggml_allocr_alloc(alloc, model->norm->grad);
+ ggml_allocr_alloc(alloc, model->output->grad);
+ for (uint32_t i = 0; i < model->layers.size(); ++i) {
+ auto & layer = model->layers[i];
+ ggml_allocr_alloc(alloc, layer.attention_norm->grad);
+ ggml_allocr_alloc(alloc, layer.wq->grad);
+ ggml_allocr_alloc(alloc, layer.wk->grad);
+ ggml_allocr_alloc(alloc, layer.wv->grad);
+ ggml_allocr_alloc(alloc, layer.wo->grad);
+ ggml_allocr_alloc(alloc, layer.ffn_norm->grad);
+ ggml_allocr_alloc(alloc, layer.w1->grad);
+ ggml_allocr_alloc(alloc, layer.w2->grad);
+ ggml_allocr_alloc(alloc, layer.w3->grad);
+ }
+}
+
+static void init_model(struct my_llama_model * model) {
const auto & hparams = model->hparams;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_vocab = hparams.n_vocab;
const uint32_t n_ff = hparams.n_ff;
- struct ggml_context * ctx = model->ctx;
-
- model->train_its = 0;
- model->train_samples = 0;
- model->train_tokens = 0;
std::vector<char> tn_buf;
tn_buf.resize(GGML_MAX_NAME);
return tn_buf.data();
};
+ // context for model tensors without their data
+ struct ggml_init_params ctx_model_params;
+ ctx_model_params.mem_size = ggml_tensor_overhead()*2*(6 + n_layer*18);
+ ctx_model_params.mem_buffer = NULL;
+ ctx_model_params.no_alloc = true;
+
+ struct ggml_context * ctx = ggml_init(ctx_model_params);
+ model->ctx = ctx;
+
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
}
-}
-void set_param_model(struct my_llama_model * model) {
- const auto& hparams = model->hparams;
+ set_param_model(model);
- const uint32_t n_layer = hparams.n_layer;
-
- struct ggml_context* ctx = model->ctx;
+ // measure data size
+ struct ggml_allocr * alloc = NULL;
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ alloc_model(alloc, model);
- ggml_set_param(ctx, model->tok_embeddings);
- ggml_set_param(ctx, model->norm);
- ggml_set_param(ctx, model->output);
-
- for (uint32_t i = 0; i < n_layer; ++i) {
- auto & layer = model->layers[i];
-
- ggml_set_param(ctx, layer.attention_norm);
- ggml_set_param(ctx, layer.wq);
- ggml_set_param(ctx, layer.wk);
- ggml_set_param(ctx, layer.wv);
- ggml_set_param(ctx, layer.wo);
- ggml_set_param(ctx, layer.ffn_norm);
- ggml_set_param(ctx, layer.w1);
- ggml_set_param(ctx, layer.w2);
- ggml_set_param(ctx, layer.w3);
- }
+ // allocate data
+ model->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
+ ggml_allocr_free(alloc);
+ alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
+ alloc_model(alloc, model);
+ ggml_allocr_free(alloc);
}
-void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
+static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
const auto & hparams = model->hparams;
const uint32_t n_layer = hparams.n_layer;
- struct random_normal_distribution rnd;
- init_random_normal_distribution(&rnd, seed, mean, std, min, max);
+ struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
- randomize_tensor_normal(model->tok_embeddings, &rnd);
- randomize_tensor_normal(model->norm, &rnd);
- randomize_tensor_normal(model->output, &rnd);
+ randomize_tensor_normal(model->tok_embeddings, rnd);
+ randomize_tensor_normal(model->norm, rnd);
+ randomize_tensor_normal(model->output, rnd);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
- randomize_tensor_normal(layer.attention_norm, &rnd);
-
- randomize_tensor_normal(layer.wq, &rnd);
- randomize_tensor_normal(layer.wk, &rnd);
- randomize_tensor_normal(layer.wv, &rnd);
- randomize_tensor_normal(layer.wo, &rnd);
-
- randomize_tensor_normal(layer.ffn_norm, &rnd);
-
- randomize_tensor_normal(layer.w1, &rnd);
- randomize_tensor_normal(layer.w2, &rnd);
- randomize_tensor_normal(layer.w3, &rnd);
- }
-}
-
-void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
- GGML_ASSERT(tensor->n_dims == 1);
- GGML_ASSERT(tensor->ne[0] == ne0);
-}
-
-void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
- GGML_ASSERT(tensor->n_dims == 2);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
-}
-
-void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
- GGML_ASSERT(tensor->n_dims == 3);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
- GGML_ASSERT(tensor->ne[2] == ne2);
-}
-
-void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
- GGML_ASSERT(tensor->n_dims == 4);
- GGML_ASSERT(tensor->ne[0] == ne0);
- GGML_ASSERT(tensor->ne[1] == ne1);
- GGML_ASSERT(tensor->ne[2] == ne2);
- GGML_ASSERT(tensor->ne[3] == ne3);
-}
-
-static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static size_t hash_find(void * hash_table[], void * p) {
- size_t h = hash(p);
-
- // linear probing
- size_t i = h;
- while (hash_table[i] != NULL && hash_table[i] != p) {
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // visited all hash table entries -> not found
- return GGML_GRAPH_HASHTABLE_SIZE;
- }
- }
- return i;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
- //size_t h = hash(p);
- size_t i = hash_find(hash_table, p);
-
- GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-
- if (hash_table[i] == p) {
- return true;
- }
-
- // insert
- GGML_ASSERT(hash_table[i] == NULL);
- hash_table[i] = p;
- return false;
-}
-
-static bool hash_contains(void * hash_table[], void * p) {
- size_t i = hash_find(hash_table, p);
- return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
-}
-
-struct hash_map {
- void * keys[GGML_GRAPH_HASHTABLE_SIZE];
- void * vals[GGML_GRAPH_HASHTABLE_SIZE];
-};
-//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
-
-struct hash_map * new_hash_map() {
- struct hash_map * result = new struct hash_map;
- for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
- result->keys[i] = NULL;
- result->vals[i] = NULL;
- }
- return result;
-};
-
-void free_hash_map(struct hash_map * map) {
- delete map;
-}
-
-static bool ggml_is_view(struct ggml_tensor * t) {
- return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
- t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
-}
-
-static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
- switch (t->op) {
- case GGML_OP_PERMUTE:
- case GGML_OP_RESHAPE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_VIEW:
- return t->src[0];
- case GGML_OP_CPY:
- return t->src[1];
- default:
- return NULL;
- }
-}
-
-static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
- struct ggml_tensor * parent = t;
- do {
- parent = get_view_parent(parent);
- } while (ggml_is_view(parent));
- return parent;
-}
-
-struct ggml_tensor * ggml_recompute_graph_node(
- struct ggml_context * ctx,
- struct ggml_cgraph * graph,
- struct hash_map * replacements,
- struct ggml_tensor * node) {
-
- if (node == NULL) {
- return NULL;
- }
-
- if (node->is_param) {
- return node;
- }
-
- if (!hash_contains(graph->visited_hash_table, node)) {
- return node;
- }
-
- int count_children = 0;
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- if (node->src[k]) {
- ++count_children;
- }
- }
-
- if (count_children == 0) {
- return node;
- }
+ randomize_tensor_normal(layer.attention_norm, rnd);
- size_t i = hash_find(replacements->keys, node);
- GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- if (replacements->keys[i] == node) {
- return (struct ggml_tensor *) replacements->vals[i];
- }
-
- struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+ randomize_tensor_normal(layer.wq, rnd);
+ randomize_tensor_normal(layer.wk, rnd);
+ randomize_tensor_normal(layer.wv, rnd);
+ randomize_tensor_normal(layer.wo, rnd);
- // insert clone into replacements
- GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
- replacements->keys[i] = node;
- replacements->vals[i] = clone;
+ randomize_tensor_normal(layer.ffn_norm, rnd);
- clone->op = node->op;
- clone->grad = node->grad;
- clone->is_param = node->is_param;
- clone->extra = node->extra;
- for (int k = 0; k < GGML_MAX_DIMS; ++k) {
- clone->nb[k] = node->nb[k];
- }
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
- }
- if (ggml_is_view(clone)) {
- struct ggml_tensor * source = get_view_source(clone);
- GGML_ASSERT(source != NULL);
- clone->data = source->data;
+ randomize_tensor_normal(layer.w1, rnd);
+ randomize_tensor_normal(layer.w2, rnd);
+ randomize_tensor_normal(layer.w3, rnd);
}
- GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
- GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
- memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
- ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
-
- return clone;
-};
-
-void ggml_build_backward_gradient_checkpointing(
- struct ggml_context * ctx,
- struct ggml_cgraph * gf,
- struct ggml_cgraph * gb,
- struct ggml_cgraph * gb_tmp,
- struct ggml_tensor * * checkpoints,
- int n_checkpoints) {
- *gb_tmp = *gf;
- ggml_build_backward_expand(ctx, gf, gb_tmp, true);
-
- if (n_checkpoints <= 0) {
- *gb = *gb_tmp;
- return;
- }
-
- struct hash_map * replacements = new_hash_map();
-
- // insert checkpoints in replacements
- for (int i = 0; i < n_checkpoints; ++i) {
- size_t k = hash_find(replacements->keys, checkpoints[i]);
- GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
- replacements->keys[k] = checkpoints[i];
- replacements->vals[k] = checkpoints[i];
- }
-
- *gb = *gf;
- // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
- // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
- // by recomputing them from checkpoints
- for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
- struct ggml_tensor * node = gb_tmp->nodes[i];
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- // insert new tensors recomputing src, reusing already made replacements,
- // remember replacements: remember new tensors with mapping from corresponding gf nodes
- // recurse for input tensors,
- // unless (i.e. terminating when) input tensors are checkpoints
- node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
- }
- // insert rewritten backward node with replacements made into resulting backward graph gb
- ggml_build_forward_expand(gb, node);
- }
-
- free_hash_map(replacements);
+ free_random_normal_distribution(rnd);
}
-struct ggml_tensor * llama_build_train_graphs(
+static struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model,
struct ggml_allocr * alloc,
struct ggml_context * ctx,
checkpoints.push_back(t00);
checkpoints.push_back(t01);
- struct ggml_tensor * kv_scale;
+ struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
- GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
+ GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
+
ggml_allocr_alloc(alloc, t36->grad);
- // gradient tensors (will be set to zero by ggml_graph_reset)
- // pinning these produces large unnecessary memory overhead, which will be resolved by PR 2632
- for (int i = 0; i < gf->n_nodes; ++i) {
- if (!gf->grads[i]) continue;
- if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
- ggml_allocr_alloc(alloc, gf->grads[i]);
- }
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
- }
+
// allocating checkpoints in one block to reduce memory fragmentation
// note: they will be freed in reverse order
for (int i = 0; i < (int) checkpoints.size(); ++i) {
- if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
+ if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
ggml_allocr_alloc(alloc, checkpoints[i]);
}
}
return t36;
}
-void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
- float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
- *ptr = value;
-}
-
-void set_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, float value) {
- float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *ptr = value;
-}
-
-void set_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int32_t value) {
- int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- *ptr = value;
-}
-
-float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
- float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- return *ptr;
-}
-
-int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
- int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
- return *ptr;
-}
-
-void print_row(struct ggml_tensor * probs, int i) {
- for (int k = 0; k < probs->ne[0]; ++k) {
- float p = get_f32_2d(probs, k, i);
- printf(" %.2f", p);
- }
- printf("\n");
-}
-
-void print_matrix(struct ggml_tensor * probs) {
- assert(probs->n_dims == 2);
- for (int i = 0; i < probs->ne[1]; ++i) {
- for (int k = 0; k < probs->ne[0]; ++k) {
- float p = get_f32_2d(probs, k, i);
- printf(" %.2f", p);
- }
- printf("\n");
- }
-}
-
-void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
- int n_tokens = tokens_input->ne[0];
- int n_vocab = target_logits->ne[0];
-
- size_t sample = train_samples[example_id % n_train_samples];
- GGML_ASSERT(sample+n_tokens-1 < n_train_data);
-
- ggml_set_f32(target_logits, -1.0f/n_vocab);
- ggml_set_f32(target_probs, 0.0f);
- ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
- for (int i=1; i<n_tokens+1; ++i) {
- int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
- set_f32_2d(target_logits, token, i-1, +1.0f);
- set_f32_2d(target_probs, token, i-1, +1.0f);
- if (i<n_tokens) {
- ggml_set_i32_1d(tokens_input, i, token);
- }
- }
-}
-
-void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
- GGML_ASSERT(tokens_input->n_dims == 2);
- GGML_ASSERT(target_logits->n_dims == 3);
- GGML_ASSERT(target_probs->n_dims == 3);
- int n_vocab = target_logits->ne[0];
- int n_tokens = tokens_input->ne[0];
- int n_batch = tokens_input->ne[1];
- GGML_ASSERT(n_tokens == target_logits->ne[1]);
- GGML_ASSERT(n_batch == target_logits->ne[2]);
- GGML_ASSERT(n_vocab == target_probs->ne[0]);
- GGML_ASSERT(n_tokens == target_probs->ne[1]);
- GGML_ASSERT(n_batch == target_probs->ne[2]);
-
- ggml_set_f32(target_logits, -1.0f/n_vocab);
- ggml_set_f32(target_probs, 0.0f);
- // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
- for (int k=0; k<n_batch; ++k) {
- // printf("%s: batch %d\n", __func__, k);
- size_t sample_idx = (example_id*n_batch + k) % n_train_samples;
- size_t sample = train_samples[sample_idx];
- // printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
- GGML_ASSERT(sample+n_tokens-1 < n_train_data);
-
- set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
- for (int i=1; i<n_tokens+1; ++i) {
- int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
- set_f32_3d(target_logits, token, i-1, k, +1.0f);
- set_f32_3d(target_probs, token, i-1, k, +1.0f);
- if (i<n_tokens) {
- set_i32_2d(tokens_input, i, k, token);
- }
- }
- }
-}
-
-int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
- FILE * fp = std::fopen(filename, "rb");
- if (fp == NULL) {
- return 0;
- }
-
-#ifdef _WIN32
- GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0);
-#else
- GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0);
-#endif
-
- size_t size = 0;
-#ifdef _WIN32
- __int64 ret = _ftelli64(fp);
- size = ret;
-#else
- long ret = std::ftell(fp);
- size = ret;
-#endif
-
-#ifdef _WIN32
- GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0);
-#else
- GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0);
-#endif
-
- std::vector<char> buf;
- buf.resize(size+1);
- out.resize(size+1);
-
- if (std::fread(buf.data(), size, 1, fp) != 1) {
- die("unexpectedly reached end of file");
- }
- if (ferror(fp)) {
- die_fmt("fread failed: %s", strerror(errno));
- }
-
- buf[size] = '\0';
-
- int n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
- if (n_tokens < 0) {
- out.resize(-n_tokens);
- n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
- }
- GGML_ASSERT(n_tokens >= 0);
- out.resize(n_tokens);
-
- bool verify = false;
- if (verify) {
- const char * in = buf.data();
- const char * end = buf.data() + buf.size();
- for (int i = 0; i < (int) out.size(); ++i) {
- std::string s = llama_token_to_piece(lctx, out[i]);
- int len = s.length();
- if (in >= end) {
- printf("%s: unexpected end of original text.\n", __func__);
- break;
- }
- const bool matches = (strncmp(in, s.c_str(), len) == 0);
- if (matches) {
- in += len;
- } else {
- printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str());
- }
- }
- }
-
- return n_tokens;
-}
-
-void shuffle_ints(int * begin, int * end) {
- if (end <= begin) return;
- int max=begin[0];
- for (int i=1; i<end-begin; ++i) {
- if (begin[i] > max) {
- max = begin[i];
- }
- }
- std::vector<float> vals;
- vals.resize(max+1);
- for (int i=0; i<max+1; ++i) {
- vals[i] = frand();
- }
- std::sort(begin, end, [&vals](int a, int b){
- return vals.at(a) < vals.at(b);
- });
-}
-
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
{ \
const std::string skey(key); \
} \
}
-
-bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
- GGML_ASSERT(a != NULL);
- GGML_ASSERT(b != NULL);
- GGML_ASSERT(a->type == b->type);
- GGML_ASSERT(ggml_are_same_shape(a, b));
- GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
-
- return true;
-}
-
-void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
- if (dst == NULL) {
- return;
- }
- struct ggml_tensor * t = ggml_get_tensor(ctx, name);
- GGML_ASSERT(are_same_layout(dst, t));
- memcpy(dst->data, t->data, ggml_nbytes(t));
-
- if (strlen(ggml_get_name(dst)) == 0) {
- ggml_set_name(dst, name);
- }
-}
-
-void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
- // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
-
- uint32_t file_version;
- GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
- GGML_ASSERT(file_version == 0);
-
- GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
- GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
- GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
-
- uint64_t nx;
- GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
- opt->nx = (size_t) nx;
-
- // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
-
- std::string opt_type;
- GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
- if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
- opt->params.type = GGML_OPT_ADAM;
-
- GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
- GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
- GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
-
- GGML_ASSERT(opt->ctx != NULL);
- ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
-
- read_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
- read_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
- read_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
- } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
- opt->params.type = GGML_OPT_LBFGS;
-
- GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
- GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
- GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
- GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
- GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
- GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
- GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
-
- GGML_ASSERT(opt->ctx != NULL);
- ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
-
- read_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
- read_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
- read_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
- read_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
- read_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
- read_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
- read_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
- read_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
- read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
- read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
- } else {
- die("unknown optimizer type");
- }
-}
-
-void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
- gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
- gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
-
- switch (opt->params.type) {
- case GGML_OPT_ADAM:
- {
- gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
-
- ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
- ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
- if (opt->adam.pf) {
- ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
- }
-
- gguf_add_tensor(fctx, opt->adam.m);
- gguf_add_tensor(fctx, opt->adam.v);
- if (opt->adam.pf) {
- gguf_add_tensor(fctx, opt->adam.pf);
- }
- } break;
- case GGML_OPT_LBFGS:
- {
- gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
- gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
- gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
- gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
-
- ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
- ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
- ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
- ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
- ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
- if (opt->lbfgs.pf) {
- ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
- }
- ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
- ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
- ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
- ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
-
- gguf_add_tensor(fctx, opt->lbfgs.x);
- gguf_add_tensor(fctx, opt->lbfgs.xp);
- gguf_add_tensor(fctx, opt->lbfgs.g);
- gguf_add_tensor(fctx, opt->lbfgs.gp);
- gguf_add_tensor(fctx, opt->lbfgs.d);
- if (opt->lbfgs.pf) {
- gguf_add_tensor(fctx, opt->lbfgs.pf);
- }
- gguf_add_tensor(fctx, opt->lbfgs.lmal);
- gguf_add_tensor(fctx, opt->lbfgs.lmys);
- gguf_add_tensor(fctx, opt->lbfgs.lms);
- gguf_add_tensor(fctx, opt->lbfgs.lmy);
- } break;
- }
-}
-
-void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
+static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) {
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
std::string arch;
init_model(model);
- read_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
- read_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
- read_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
+ copy_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD));
+ copy_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM));
+ copy_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT));
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];
- read_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
- read_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
- read_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
- read_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
- read_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
- read_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
- read_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
- read_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
- read_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
+ copy_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i));
+ copy_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i));
+ copy_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i));
+ copy_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
+ copy_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
+ copy_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
+ copy_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
+ copy_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
+ copy_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
}
}
-void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
+static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) {
const char * arch = "llama";
enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
}
}
-void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
+static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) {
+ printf("%s: saving to %s\n", __func__, filename);
struct gguf_context * fctx = gguf_init_empty();
save_llama_model_gguf(fctx, fn_vocab_model, model);
gguf_free(fctx);
}
-void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) {
+static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
load_llama_model_gguf(fctx, f_ggml_ctx, model);
-
- uint32_t file_version;
- GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
- GGML_ASSERT(file_version == 0);
-
- GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
- GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
- GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
-
- load_opt_context_gguf(fctx, f_ggml_ctx, opt);
+ if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
+ std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
+ GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+ GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
+ } else {
+ printf("%s: loaded llama model as checkpoint\n", __func__);
+ }
}
-void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
+static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
+ gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
save_llama_model_gguf(fctx, fn_vocab_model, model);
-
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
- gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);
-
- save_opt_context_gguf(fctx, opt);
+ save_train_state_gguf(fctx, train);
}
-bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) {
+static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct train_state * train) {
struct ggml_context * f_ggml_ctx;
struct gguf_init_params params;
params.no_alloc = false;
return false;
}
- load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt);
+ load_checkpoint_gguf(fctx, f_ggml_ctx, model, train);
return true;
}
-void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {
+static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
+ printf("%s: saving to %s\n", __func__, filename);
struct gguf_context * fctx = gguf_init_empty();
- save_checkpoint_gguf(fctx, fn_vocab_model, model, opt);
+ save_checkpoint_gguf(fctx, fn_vocab_model, model, train);
// write file
const bool only_meta = false;
gguf_free(fctx);
}
-float cosine_decay(const int decay_steps, const float minimum, int step) {
- if (step > decay_steps) {
- step = decay_steps;
- }
- const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
- const float decay = (1 - minimum)*cosine_decay + minimum;
- return decay;
-}
-
-float cosine_decay_restart(int decay_steps, const float minimum, int step, float restart_step_mult, bool enable_restart) {
- if (enable_restart) {
- while (step > decay_steps) {
- step -= decay_steps;
- decay_steps = (int) restart_step_mult * decay_steps;
- }
- }
- return cosine_decay(decay_steps, minimum, step);
-}
-
struct train_params {
+ struct train_params_common common;
+
const char * fn_vocab_model;
- const char * fn_train_data;
- const char * fn_checkpoint_in;
- const char * fn_checkpoint_out;
const char * fn_model_out;
- uint32_t seed;
+ bool only_write_model;
int n_ctx;
int n_embd;
int n_layer;
int n_ff;
- int n_threads;
- int n_batch;
- int n_examples;
-
float f_norm_rms_eps;
float rope_freq_base;
float rope_freq_scale;
-
- int print_info_interval;
-
- bool samples_start_after_nl;
- bool use_adam;
- bool use_flash;
- bool use_checkpointing;
- bool use_alloc;
-
- // only adam
- int warmup;
- int cos_decay_steps;
- float cos_decay_restart;
- float cos_decay_min;
- bool enable_restart;
-
- int opt_past;
- float opt_delta;
- int opt_max_no_improvement;
-
- int lbfgs_n_iter;
- int adam_n_iter;
- float adam_alpha;
- float adam_min_alpha;
- float adam_decay;
- int adam_decay_min_ndim;
- float adam_beta1;
- float adam_beta2;
- float adam_gclip;
- float adam_eps_f;
-
- int mem_model_gb;
- int mem_compute_gb;
- int mem_compute0_gb;
};
struct train_params get_default_train_params() {
struct train_params params;
+ params.common = get_default_train_params_common();
params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
- params.fn_train_data = "shakespeare.txt";
- params.fn_checkpoint_in = "checkpoint.bin";
- params.fn_checkpoint_out = "checkpoint.bin";
params.fn_model_out = "ggml-checkpoint-f32.bin";
- params.seed = -1;
+ params.only_write_model = false;
params.n_ctx = 128;
params.n_embd = 256;
params.n_layer = 16;
params.n_ff = 768;
- params.n_threads = 6;
- params.n_batch = 8;
- params.n_examples = 1;
-
- params.f_norm_rms_eps = 1e-5;
+ params.f_norm_rms_eps = 1e-5f;
params.rope_freq_base = 10000.0f;
params.rope_freq_scale = 1.0f;
- params.print_info_interval = 1;
-
- params.samples_start_after_nl = false;
- params.use_adam = true;
- params.use_flash = true;
- params.use_checkpointing = true;
- params.use_alloc = true;
-
- params.opt_past = 0;
- params.opt_delta = 1e-5f;
- params.opt_max_no_improvement = 0;
-
- // only adam
- params.warmup = 100;
- params.cos_decay_steps = 1000;
- params.cos_decay_restart = 1.1f;
- params.cos_decay_min = 0.1f;
- params.enable_restart = false;
-
- params.lbfgs_n_iter = 256;
- params.adam_n_iter = 256;
- params.adam_alpha = 1e-3f;
- params.adam_min_alpha = 0;
- params.adam_decay = 1e-1f;
- params.adam_decay_min_ndim = 2;
- params.adam_beta1 = 0.9f;
- params.adam_beta2 = 0.999f;
- params.adam_gclip = 1.0f;
- params.adam_eps_f = 0.0f;
-
- params.mem_model_gb = 2;
- params.mem_compute_gb = 24;
- params.mem_compute0_gb = 8;
return params;
}
-void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
+
fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
- fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data);
- fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
- fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out);
- fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n");
- fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
+ fprintf(stderr, " --only-write-model only save llama model, don't do any training. use this if you only want to convert a checkpoint to a model.\n");
fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd);
fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff);
fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head);
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
- fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
- fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
- fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
- fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
- fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
- fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
- fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
- fprintf(stderr, " --no-flash Don't use flash attention \n");
- fprintf(stderr, " --use-flash Use flash attention (default)\n");
- fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
- fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n");
- fprintf(stderr, " --no-alloc Don't use allocator\n");
- fprintf(stderr, " --use-alloc Use allocator (default)\n");
- fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
- fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
- fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
- fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
- fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
- fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
- fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
- fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
- fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
- fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
- fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
- fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
- fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
- fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
- fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
- fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
- fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
- fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
- fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
- fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
- fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
- fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
- fprintf(stderr, "\n");
+
+ print_common_train_usage(argc, argv, ¶ms->common);
}
-bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
bool invalid_param = false;
std::string arg;
struct train_params default_params = get_default_train_params();
std::replace(arg.begin(), arg.end(), '_', '-');
}
- if (arg == "--vocab-model") {
- if (++i >= argc) {
- invalid_param = true;
+ if (consume_common_train_arg(argc, argv, &i, ¶ms->common, &invalid_param)) {
+ if (invalid_param) {
break;
+ } else if (params->common.print_usage) {
+ train_print_usage(argc, argv, &default_params);
+ exit(0);
}
- params->fn_vocab_model = argv[i];
- } else if (arg == "--train-data") {
+ } else if (arg == "--vocab-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params->fn_train_data = argv[i];
- } else if (arg == "--checkpoint-in") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_checkpoint_in = argv[i];
- } else if (arg == "--checkpoint-out") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->fn_checkpoint_out = argv[i];
+ params->fn_vocab_model = argv[i];
} else if (arg == "--model-out") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_model_out = argv[i];
- } else if (arg == "-s" || arg == "--seed") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->seed = std::stoi(argv[i]);
- } else if (arg == "-c" || arg == "--ctx") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_ctx = std::stoi(argv[i]);
+ } else if (arg == "--only-write-model") {
+ params->only_write_model = true;
} else if (arg == "--embd") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->rope_freq_scale = std::stof(argv[i]);
- } else if (arg == "-t" || arg == "--threads") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_threads = std::stoi(argv[i]);
- } else if (arg == "-b" || arg == "--batch") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_batch = std::stoi(argv[i]);
- } else if (arg == "-n" || arg == "--examples") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->n_examples = std::stoi(argv[i]);
- } else if (arg == "--print-info-interval") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->print_info_interval = std::stoi(argv[i]);
- } else if (arg == "--samples-after-nl") {
- params->samples_start_after_nl = true;
- } else if (arg == "--use-lbfgs") {
- params->use_adam = false;
- } else if (arg == "--use-adam") {
- params->use_adam = true;
- } else if (arg == "--no-flash") {
- params->use_flash = false;
- } else if (arg == "--use-flash") {
- params->use_flash = true;
- } else if (arg == "--no-checkpointing") {
- params->use_checkpointing = false;
- } else if (arg == "--use-checkpointing") {
- params->use_checkpointing = true;
- } else if (arg == "--no-alloc") {
- params->use_alloc = false;
- } else if (arg == "--use-alloc") {
- params->use_alloc = true;
- } else if (arg == "--warmup") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->warmup = std::stoi(argv[i]);
- } else if (arg == "--cos-decay-steps") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_steps = std::stof(argv[i]);
- } else if (arg == "--cos-decay-restart") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_restart = std::stof(argv[i]);
- } else if (arg == "--cos-decay-min") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->cos_decay_min = std::stof(argv[i]);
- } else if (arg == "--enable-restart") {
- params->enable_restart = true;
- } else if (arg == "--disable-restart") {
- params->enable_restart = false;
- } else if (arg == "--opt-past") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_past = std::stoi(argv[i]);
- } else if (arg == "--opt-delta") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_delta = std::stof(argv[i]);
- } else if (arg == "--opt-max-no-improvement") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->opt_max_no_improvement = std::stoi(argv[i]);
- } else if (arg == "--adam-epsf") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_eps_f = std::stof(argv[i]);
- } else if (arg == "--adam-iter") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_n_iter = std::stoi(argv[i]);
- } else if (arg == "--adam-alpha") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_alpha = std::stof(argv[i]);
- } else if (arg == "--adam-min-alpha") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_min_alpha = std::stof(argv[i]);
- } else if (arg == "--adam-decay") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_decay = std::stof(argv[i]);
- } else if (arg == "--adam-decay-min-ndim") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_decay_min_ndim = std::stoi(argv[i]);
- } else if (arg == "--adam-beta1") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_beta1 = std::stof(argv[i]);
- } else if (arg == "--adam-beta2") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_beta2 = std::stof(argv[i]);
- } else if (arg == "--adam-gclip") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->adam_gclip = std::stof(argv[i]);
- } else if (arg == "--lbfgs-iter") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->lbfgs_n_iter = std::stoi(argv[i]);
- } else if (arg == "--mem-model") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_model_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_compute_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute0") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params->mem_compute0_gb = std::stoi(argv[i]);
- } else if (arg == "-h" || arg == "--help") {
- train_print_usage(argc, argv, &default_params);
- exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params);
train_print_usage(argc, argv, &default_params);
exit(1);
}
+ finish_processing_train_args(¶ms->common);
return true;
}
-struct opt_callback_data {
- struct train_params * params;
- struct ggml_opt_context * opt;
- struct llama_context * lctx;
- llama_token * tokens_data;
- size_t tokens_size;
- int * samples_data;
- size_t samples_size;
- int shuffle_countdown;
- struct ggml_tensor * tokens_input;
- struct ggml_tensor * target_logits;
- struct ggml_tensor * target_probs;
+struct save_train_files_data {
+ const char * fn_checkpoint_out;
+ const char * fn_model_out;
+ const char * fn_vocab_model;
+ const char * pattern_fn_it;
+ const char * fn_latest;
+ struct my_llama_model * model;
};
-void opt_callback(void * vdata, float * sched) {
- struct opt_callback_data * data = (struct opt_callback_data *) vdata;
- struct train_params * params = data->params;
- struct ggml_opt_context * opt = data->opt;
- int n_batch = params->n_batch;
-
- *sched = (opt->iter < params->warmup)
- ? (float) opt->iter / (float) params->warmup
- : cosine_decay_restart(
- params->cos_decay_steps,
- params->cos_decay_min,
- opt->iter - params->warmup,
- params->cos_decay_restart,
- params->enable_restart);
- float min_sched = params->adam_min_alpha / params->adam_alpha;
- *sched = min_sched + *sched * (1.0f - min_sched);
-
- int impr_plot = std::isnan(opt->loss_after) ? 0 : -std::lround(1 + (opt->loss_before - opt->loss_after) * 10.0f);
- printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
-
- if (data->shuffle_countdown < n_batch) {
- printf("%s: reshuffle samples\n", __func__);
- shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
- for (int i = 0; i < (int) data->samples_size; ++i) {
- GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size);
- }
- data->shuffle_countdown = data->samples_size;
+static void save_train_files(void * vdata, struct train_state * train) {
+ struct save_train_files_data * data = (struct save_train_files_data *) vdata;
+ int64_t iter = train->opt->iter;
+
+ if (strlen(data->fn_checkpoint_out) > 0) {
+ save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model, train);
+ save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model, train);
+
+ }
+ if (strlen(data->fn_model_out) > 0) {
+ save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model);
+ save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model);
}
+}
+
+static int64_t get_parameter_count(struct my_llama_model* model) {
+ int64_t nx = 0;
+ nx += ggml_nelements(model->tok_embeddings);
+ nx += ggml_nelements(model->norm);
+ nx += ggml_nelements(model->output);
- get_example_targets_batch(
- data->lctx,
- data->samples_data,
- data->samples_size,
- data->tokens_data,
- data->tokens_size,
- opt->iter,
- data->tokens_input,
- data->target_logits,
- data->target_probs);
-
- data->shuffle_countdown -= n_batch;
+ for (uint32_t i = 0; i < model->layers.size(); ++i) {
+ auto & layer = model->layers[i];
+ nx += ggml_nelements(layer.attention_norm);
+ nx += ggml_nelements(layer.wq);
+ nx += ggml_nelements(layer.wk);
+ nx += ggml_nelements(layer.wv);
+ nx += ggml_nelements(layer.wo);
+ nx += ggml_nelements(layer.ffn_norm);
+ nx += ggml_nelements(layer.w1);
+ nx += ggml_nelements(layer.w2);
+ nx += ggml_nelements(layer.w3);
+ }
+ return nx;
}
int main(int argc, char ** argv) {
return 1;
}
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
+ if (params.common.seed == LLAMA_DEFAULT_SEED) {
+ params.common.seed = time(NULL);
}
- printf("%s: seed: %u\n", __func__, params.seed);
- srand(params.seed);
+ printf("%s: seed: %u\n", __func__, params.common.seed);
+ srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
- printf("%s: tokenize training data\n", __func__);
- std::vector<llama_token> train_tokens;
- if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
- fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
- }
- printf("%s: number of training tokens: %d\n", __func__, (int) train_tokens.size());
-
struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx);
- model.hparams.n_ctx = params.n_ctx;
+ model.hparams.n_ctx = params.common.n_ctx;
model.hparams.n_embd = params.n_embd;
model.hparams.n_head = params.n_head;
model.hparams.n_layer = params.n_layer;
model.hparams.rope_freq_base = params.rope_freq_base;
model.hparams.rope_freq_scale = params.rope_freq_scale;
- print_params(&model.hparams);
-
- std::vector<size_t> token_noccurs;
- std::vector<bool> token_notavail;
- token_noccurs.resize(model.hparams.n_vocab, 0);
- token_notavail.resize(model.hparams.n_vocab, true);
- for (int i = 0; i < (int) train_tokens.size(); ++i) {
- ++token_noccurs[train_tokens[i]];
- token_notavail[train_tokens[i]] = false;
- }
-
- std::vector<float> token_freq;
- token_freq.resize(model.hparams.n_vocab, 0);
- int n_unique_tokens = 0;
- for (int i = 0; i < (int) token_noccurs.size(); ++i) {
- token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
- n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
- }
- printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
+ struct train_state * train = init_train_state();
+ struct ggml_opt_context * opt = train->opt;
+
+ // set opt params from command line
+ opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+ opt->params.print_forward_graph = false;
+ opt->params.print_backward_graph = false;
+ opt->params.n_threads = params.common.n_threads;
+ opt->params.past = params.common.opt_past;
+ opt->params.delta = params.common.opt_delta;
+ opt->params.max_no_improvement = params.common.opt_max_no_improvement;
+ opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
+ opt->params.adam.n_iter = params.common.adam_n_iter;
+ opt->params.adam.sched = 1.0f;
+ opt->params.adam.alpha = params.common.adam_alpha;
+ opt->params.adam.decay = params.common.adam_decay;
+ opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
+ opt->params.adam.beta1 = params.common.adam_beta1;
+ opt->params.adam.beta2 = params.common.adam_beta2;
+ opt->params.adam.gclip = params.common.adam_gclip;
+ opt->params.adam.eps_f = params.common.adam_eps_f;
- struct ggml_init_params lcparams;
- lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
- lcparams.mem_buffer = NULL;
- lcparams.no_alloc = false;
+ printf("%s: init model\n", __func__);
+ bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
+ if (existed) {
+ // overwrite last n_ctx with user provided n_ctx
+ if (params.common.custom_n_ctx) {
+ model.hparams.n_ctx = params.common.n_ctx;
+ }
- model.ctx = ggml_init(lcparams);
+ const bool opt_past_changed = opt->params.past != params.common.opt_past;
- int n_tokens = model.hparams.n_ctx;
- int n_vocab = model.hparams.n_vocab;
- int n_batch = params.n_batch;
-
- struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
- memset(opt, 0, sizeof(struct ggml_opt_context));
-
- struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
- struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
- opt_params_adam.print_forward_graph = false;
- opt_params_adam.print_backward_graph = false;
- opt_params_adam.n_threads = params.n_threads;
- opt_params_adam.past = params.opt_past;
- opt_params_adam.delta = params.opt_delta;
- opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
- opt_params_adam.adam.n_iter = params.adam_n_iter;
- opt_params_adam.adam.sched = 1.0f;
- opt_params_adam.adam.alpha = params.adam_alpha;
- opt_params_adam.adam.decay = params.adam_decay;
- opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
- opt_params_adam.adam.beta1 = params.adam_beta1;
- opt_params_adam.adam.beta2 = params.adam_beta2;
- opt_params_adam.adam.gclip = params.adam_gclip;
- opt_params_adam.adam.eps_f = params.adam_eps_f;
-
- opt_params_lbfgs.print_forward_graph = false;
- opt_params_lbfgs.print_backward_graph = false;
- opt_params_lbfgs.n_threads = params.n_threads;
- opt_params_adam.past = params.opt_past;
- opt_params_adam.delta = params.opt_delta;
- opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
- opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
-
- opt->ctx = model.ctx;
- opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
-
- printf("%s: init model\n", __func__);
- bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt);
- if (!existed) {
+ if (opt_past_changed) {
+ die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value train from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
+ // need to discard previous optimizer past function value statistics and opt_init with new shapes
+ // TODO
+ }
+ } else {
init_model(&model);
+ randomize_model(&model, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+ if (!params.only_write_model) {
+ ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&model));
+ }
}
- set_param_model(&model);
-
- opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
+ opt->iter = train->train_its;
- opt->iter = model.train_its;
- printf("%s: opt iter %d\n", __func__, opt->iter);
-
- bool from_scratch = !existed;
- if (from_scratch) {
- randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
+ print_params(&model.hparams);
+ printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
+ printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
+ printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
+ printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
+ printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f));
+
+ if (params.only_write_model) {
+ save_train_files_data save_data;
+ save_data.fn_checkpoint_out = "";
+ save_data.fn_model_out = params.fn_model_out;
+ save_data.fn_vocab_model = params.fn_vocab_model;
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
+ save_data.fn_latest = params.common.fn_latest;
+ save_data.model = &model;
+
+ save_train_files(&save_data, train);
+
+ free_train_state(train);
+ ggml_free(model.ctx);
+ llama_free(lctx);
+ llama_free_model(lmodel);
+ return 0;
}
- printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx));
- // ggml_print_tensor_objects(model.ctx);
+ printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
+ printf("%s: opt iter %d\n", __func__, opt->iter);
- // TODO: use std::vector<uint8_t> intead of "new"
- size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
- uint8_t * compute_addr = new uint8_t[compute_size];
+ int n_tokens = model.hparams.n_ctx;
+ int n_vocab = model.hparams.n_vocab;
+ int n_batch = params.common.n_batch;
- size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
- uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
+ std::vector<uint8_t> mem_input_data;
+ std::vector<uint8_t> mem_compute_data;
ggml_allocr * alloc = NULL;
- if (params.use_alloc) {
- static const size_t tensor_alignment = 32;
- alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
- }
-
- GGML_ASSERT(n_tokens < (int) train_tokens.size());
- std::vector<int> train_samples;
- train_samples.push_back(0);
- for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
- if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
- train_samples.push_back(i);
- }
- }
- shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
- for (int i = 0; i < (int) train_samples.size(); ++i) {
- GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
- }
-
- printf("%s: begin training\n", __func__);
-
- struct opt_callback_data opt_cb_data;
- opt_cb_data.params = ¶ms;
- opt_cb_data.opt = opt;
- opt_cb_data.lctx = lctx;
- opt_cb_data.tokens_data = train_tokens.data();
- opt_cb_data.tokens_size = train_tokens.size();
- opt_cb_data.samples_data = train_samples.data();
- opt_cb_data.samples_size = train_samples.size();
- opt_cb_data.shuffle_countdown = train_samples.size();
- opt_cb_data.tokens_input = NULL;
- opt_cb_data.target_logits = NULL;
- opt_cb_data.target_probs = NULL;
-
- int64_t t0 = ggml_time_ms();
-
- for (int ex = 0; ex < params.n_examples; ++ex) {
- if (ex*n_batch >= (int) train_samples.size()) {
- shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
- for (int i = 0; i < (int) train_samples.size(); ++i) {
- GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
- }
- }
-
- struct ggml_init_params cparams = {
- compute_size, // mem_size
- compute_addr, // mem_buffer
- false, // no_alloc
- };
- struct ggml_context * ctx0 = ggml_init(cparams);
-
- ggml_set_no_alloc(ctx0, false);
-
- // don't use alloc for input tensors, so we can safely fill them with data
- //struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
- //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
- struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
- struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
-
- ggml_set_no_alloc(ctx0, (alloc != NULL));
- if (alloc) {
- ggml_allocr_reset(alloc);
- }
-
- opt_cb_data.tokens_input = tokens_input;
- opt_cb_data.target_logits = target_logits;
- opt_cb_data.target_probs = target_probs;
-
- int n_past = 0;
-
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- struct ggml_cgraph * gb = ggml_new_graph(ctx0);
- struct ggml_cgraph * gb_tmp = params.use_checkpointing
- ? ggml_new_graph(ctx0)
+ // context for input tensors without their data
+ struct ggml_init_params ctx_input_params = {
+ ggml_tensor_overhead() * 2, // mem_size
+ NULL, // mem_buffer
+ true, // no_alloc
+ };
+ struct ggml_context * ctx_input = ggml_init(ctx_input_params);
+
+ // the input tensors
+ struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
+ struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+
+ // measure required memory for input tensors
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_allocr_alloc(alloc, tokens_input);
+ ggml_allocr_alloc(alloc, target_probs);
+ size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ ggml_allocr_free(alloc);
+ printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
+
+ // allocate input tensors
+ mem_input_data.resize(max_input_size);
+ alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
+ ggml_allocr_alloc(alloc, tokens_input);
+ ggml_allocr_alloc(alloc, target_probs);
+ ggml_allocr_free(alloc);
+
+ // context for compute tensors without their data
+ size_t estimated_compute_size_wo_data = (
+ ggml_tensor_overhead()*GGML_MAX_NODES*2
+ + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
+ params.common.use_checkpointing ? 3 : 2
+ )
+ );
+ struct ggml_init_params ctx_compute_params = {
+ estimated_compute_size_wo_data, // mem_size
+ NULL, // mem_buffer
+ true, // no_alloc
+ };
+ struct ggml_context * ctx_compute = NULL;
+
+ struct ggml_tensor * loss = NULL;
+ struct ggml_tensor * logits = NULL;
+
+ struct ggml_cgraph * gf = NULL;
+ struct ggml_cgraph * gb = NULL;
+ struct ggml_cgraph * gb_tmp = NULL;
+
+ // measure required memory for compute tensors
+ size_t best_compute_size = SIZE_MAX;
+ enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
+ // find best evaluation order
+ for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
+ ctx_compute = ggml_init(ctx_compute_params);
+ alloc = ggml_allocr_new_measure(tensor_alignment);
+ gf = ggml_new_graph(ctx_compute);
+ gf->order = (enum ggml_cgraph_eval_order) order;
+ gb = ggml_new_graph(ctx_compute);
+ gb_tmp = params.common.use_checkpointing
+ ? ggml_new_graph(ctx_compute)
: NULL;
-
- GGML_ASSERT(n_past == 0);
-
- struct ggml_tensor * loss = NULL;
- struct ggml_tensor * logits = NULL;
-
loss = llama_build_train_graphs(
- &model, alloc, ctx0,
+ &model, alloc, ctx_compute,
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
- params.use_flash,
- params.use_checkpointing
+ params.common.use_flash,
+ params.common.use_checkpointing
);
+ size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ if (max_compute_size < best_compute_size) {
+ best_compute_size = max_compute_size;
+ best_order = gf->order;
+ }
+ ggml_allocr_free(alloc);
+ ggml_free(ctx_compute);
+ }
+ size_t max_compute_size = best_compute_size;
+ printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
+ printf("%s: evaluation order = %s\n", __func__,
+ (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
+ (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
+ "invalid");
+
+ // allocate compute tensors
+ mem_compute_data.resize(max_compute_size);
+ ctx_compute = ggml_init(ctx_compute_params);
+ alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ gf = ggml_new_graph(ctx_compute);
+ gf->order = best_order;
+ gb = ggml_new_graph(ctx_compute);
+ gb_tmp = params.common.use_checkpointing
+ ? ggml_new_graph(ctx_compute)
+ : NULL;
+ loss = llama_build_train_graphs(
+ &model, alloc, ctx_compute,
+ gf, gb, gb_tmp,
+ &logits, tokens_input, target_probs,
+ n_tokens, n_batch,
+ params.common.use_flash,
+ params.common.use_checkpointing
+ );
+ ggml_allocr_free(alloc);
- size_t used_mem_before_opt = ggml_used_mem(ctx0);
-
- opt->params.adam.sched = (opt->iter < params.warmup)
- ? (float) opt->iter / (float) params.warmup
- : cosine_decay_restart(
- params.cos_decay_steps,
- params.cos_decay_min,
- opt->iter - params.warmup,
- params.cos_decay_restart,
- params.enable_restart);
-
- float min_sched = params.adam_min_alpha / params.adam_alpha;
- opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
-
- printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
-
- ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
+ std::vector<llama_token> train_tokens;
+ std::vector<size_t> train_samples_begin;
+ std::vector<size_t> train_samples_size;
+ printf("%s: tokenize training data\n", __func__);
+ tokenize_file(lctx,
+ params.common.fn_train_data,
+ params.common.sample_start,
+ params.common.include_sample_start,
+ params.common.overlapping_samples,
+ n_tokens,
+ train_tokens,
+ train_samples_begin,
+ train_samples_size);
+ GGML_ASSERT(train_samples_begin.size() == train_samples_size.size());
+
+ printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
+
+ size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
+ const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
+ if (changed_train_data) {
+ printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
+ }
+ if (params.common.force_reshuffle) {
+ printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
+ }
+ if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
+ train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
+ train->shuffle_sample_count = train_samples_size.size();
+ train->shuffle_next_sample = 0;
+ train->shuffle_samples_hash = shuffle_samples_hash;
+ }
+ std::vector<size_t> train_shuffled_samples_offs;
+ std::vector<size_t> train_shuffled_samples_begin;
+ std::vector<size_t> train_shuffled_samples_size;
+ train_shuffled_samples_offs.resize(train_samples_begin.size());
+ train_shuffled_samples_begin.resize(train_samples_begin.size());
+ train_shuffled_samples_size.resize(train_samples_size.size());
+ train->shuffle_rng_state_next = shuffle_samples(
+ train->shuffle_rng_state_current,
+ train_shuffled_samples_offs.data(),
+ train_shuffled_samples_begin.data(),
+ train_shuffled_samples_size.data(),
+ train_samples_begin.data(),
+ train_samples_size.data(),
+ train_samples_size.size());
+ printf("%s: begin training\n", __func__);
- size_t used_mem_after_opt = ggml_used_mem(ctx0);
+ save_train_files_data save_data;
+ save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
+ save_data.fn_model_out = params.fn_model_out;
+ save_data.fn_vocab_model = params.fn_vocab_model;
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
+ save_data.fn_latest = params.common.fn_latest;
+ save_data.model = &model;
+
+ struct train_opt_callback_data opt_cb_data;
+ opt_cb_data.params = ¶ms.common;
+ opt_cb_data.train = train;
+ opt_cb_data.save_cb = &save_train_files;
+ opt_cb_data.save_data = &save_data;
+ opt_cb_data.lctx = lctx;
+ opt_cb_data.last_save_iter = opt->iter;
+ opt_cb_data.tokens_data = train_tokens.data();
+ opt_cb_data.tokens_size = train_tokens.size();
+ opt_cb_data.samples_begin = train_samples_begin.data();
+ opt_cb_data.samples_size = train_samples_size.data();
+ opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
+ opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
+ opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
+ opt_cb_data.samples_count = train_samples_size.size();
+ opt_cb_data.tokens_input = tokens_input;
+ opt_cb_data.target_probs = target_probs;
+ opt_cb_data.first_iter = opt->iter;
+ opt_cb_data.first_epoch = train->train_epochs;
+ opt_cb_data.iter_at_last_epoch = -1;
+ opt_cb_data.last_time = ggml_time_ms();
+ opt_cb_data.millis_per_iter = 0.0;
+
+ // measure required memory for work buffer
+ size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
+ printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
+
+ // context for work buffer
+ struct ggml_init_params ctx_work_params = {
+ max_work_size, // mem_size
+ NULL, // mem_buffer
+ false, // no_alloc
+ };
+ struct ggml_context * ctx_work = ggml_init(ctx_work_params);
- int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
- model.train_its = opt->iter;
- model.train_samples += n_batch * n_iter;
- model.train_tokens += n_batch * n_tokens * n_iter;
+ int64_t t0 = ggml_time_ms();
- if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
- printf("Example %d, opt iter %d\n", ex, opt->iter);
- printf("error_before_opt: %.6f\n", opt->loss_before);
- printf("error_after_opt: %.6f\n", opt->loss_after);
- printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
- printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
- }
+ ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
- ggml_free(ctx0);
- }
+ ggml_free(ctx_work);
+ ggml_free(ctx_compute);
+ ggml_free(ctx_input);
int64_t t1 = ggml_time_ms();
- int64_t d = t1-t0;
- double dd = (double) d * 1e-3;
- printf("%s: total training time=%f seconds\n", __func__, dd);
+ printf("%s: total training time: ", __func__);
+ print_duration((double) (t1 - t0));
+ printf("\n");
- if (params.n_examples > 0) {
- save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
- }
+ int new_iters = opt->iter - opt_cb_data.last_save_iter;
+ if (new_iters > 0) {
+ train->train_its += new_iters;
+ train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
- if (strlen(params.fn_model_out) > 0) {
- save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model);
+ save_train_files(&save_data, train);
+ opt_cb_data.last_save_iter = opt->iter;
}
if (alloc) {
ggml_allocr_free(alloc);
}
- delete[] compute_addr;
- delete[] compute_buf_0;
+ ggml_free(opt->ctx);
+ free_train_state(train);
ggml_free(model.ctx);
llama_free(lctx);
llama_free_model(lmodel);
size_t size;
};
-#define MAX_FREE_BLOCKS 128
+#define MAX_FREE_BLOCKS 256
struct ggml_allocr {
void * data;
}
tensor->data = addr;
+ AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
- AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+ AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+ AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size);
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}
+
+size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
+ return alloc->max_size;
+}
GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc);
#ifdef __cplusplus
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2
+#define GGML_VEC_MAD_UNROLL 32
//
// logging
#endif
}
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+ x[i] = (const float *) ((const char *) xv + i*xs);
+ v[i] = (const float *) ((const char *) vv + i*vs);
+ }
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }
+
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#else
+ // scalar
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#endif
+}
+
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
- return
- (t0->ne[1] == t1->ne[1]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[1] == t1->ne[1]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
}
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
return tensor;
}
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne0 = tensor->ne[0];
+
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+ if (i0) {
+ * i0 = i0_;
+ }
+ if (i1) {
+ * i1 = i1_;
+ }
+ if (i2) {
+ * i2 = i2_;
+ }
+ if (i3) {
+ * i3 = i3_;
+ }
+}
+
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
}
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
}
}
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ return ((int8_t *) data)[0];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ return ((int16_t *) data)[0];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ return ((int32_t *) data)[0];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ return ((float *) data)[0];
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
}
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
switch (tensor->type) {
case GGML_TYPE_I8:
{
}
}
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ return ((int8_t *) data)[0];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ return ((int16_t *) data)[0];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ return ((int32_t *) data)[0];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ return ((float *) data)[0];
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
void * ggml_get_data(const struct ggml_tensor * tensor) {
return tensor->data;
}
return ggml_add_impl(ctx, a, b, true);
}
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
+ GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ return ggml_add_cast_impl(ctx, a, b, type);
+}
+
// ggml_add1
static struct ggml_tensor * ggml_add1_impl(
result->op = GGML_OP_REPEAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
- result->src[1] = b;
return result;
}
result->op = GGML_OP_REPEAT_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
- result->src[1] = b;
return result;
}
is_node = true;
}
- const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
result->op = GGML_OP_OUT_PROD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_is_contiguous(a));
- GGML_ASSERT(ggml_is_contiguous(b));
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
bool is_node = false;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
- result->src[2] = c;
return result;
}
// d shape [D,N,ne2,ne3]
// q shape [D,N,ne2,ne3]
- // k shape [D,M,ne2,ne3]
- // v shape [M,D,ne2,ne3]
+ // k shape [D,M,kvne2,ne3]
+ // v shape [M,D,kvne2,ne3]
- const int64_t D = q->ne[0];
- const int64_t N = q->ne[1];
- const int64_t M = k->ne[1];
- const int64_t ne2 = q->ne[2];
- const int64_t ne3 = q->ne[3];
+ const int64_t D = q->ne[0];
+ const int64_t N = q->ne[1];
+ const int64_t M = k->ne[1];
+ const int64_t ne2 = q->ne[2];
+ const int64_t ne3 = q->ne[3];
+ const int64_t kvne2 = k->ne[2];
GGML_ASSERT(k->ne[0] == D);
GGML_ASSERT(v->ne[0] == M);
GGML_ASSERT(v->ne[1] == D);
GGML_ASSERT(d->ne[0] == D);
GGML_ASSERT(d->ne[1] == N);
- GGML_ASSERT(k->ne[2] == ne2);
+ GGML_ASSERT(k->ne[2] == kvne2);
GGML_ASSERT(k->ne[3] == ne3);
- GGML_ASSERT(v->ne[2] == ne2);
+ GGML_ASSERT(v->ne[2] == kvne2);
GGML_ASSERT(v->ne[3] == ne3);
GGML_ASSERT(d->ne[2] == ne2);
GGML_ASSERT(d->ne[3] == ne3);
+ GGML_ASSERT(ne2 % kvne2 == 0);
+
bool is_node = false;
if (q->grad || k->grad || v->grad) {
}
// store gradients of q, k and v as continuous tensors concatenated in result.
- // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
- // gradq->data = result->data
- // gradk->data = result->data + nb0*D*N*ne2*ne3
- // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
- int64_t ne[4] = {D,M+N+M,ne2,ne3};
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
+ const int64_t elem_v = ggml_nelements(v);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+ enum ggml_type result_type = GGML_TYPE_F32;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+ const size_t nelements = (end + tsize - 1)/tsize;
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
int32_t masked_i = masked ? 1 : 0;
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
const int nth = params->nth;
const enum ggml_type type = src0->type;
+ const enum ggml_type dtype = dst->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ggml_is_quantized(src0->type));
- GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// rows per thread
// add src1
ggml_vec_acc_f32(ne00, wdata, src1_row);
// quantize row to dst
- quantize_row_q(wdata, dst_row, ne00);
+ if (quantize_row_q != NULL) {
+ quantize_row_q(wdata, dst_row, ne00);
+ } else {
+ memcpy(dst_row, wdata, ne0*nb0);
+ }
}
}
}
}
+static void ggml_compute_forward_repeat_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(params->ith == 0);
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nr0 = (int)(ne0/ne00);
+ const int nr1 = (int)(ne1/ne01);
+ const int nr2 = (int)(ne2/ne02);
+ const int nr3 = (int)(ne3/ne03);
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // TODO: maybe this is not optimal?
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne03; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne02; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne01; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
+ // ggml_vec_cpy_f16(ne00, y, x)
+ for (int i = 0; i < ne00; ++i) {
+ y[i] = x[i];
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
static void ggml_compute_forward_repeat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_repeat_f16(params, src0, dst);
+ } break;
case GGML_TYPE_F32:
{
ggml_compute_forward_repeat_f32(params, src0, dst);
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- int64_t t0 = ggml_perf_time_us();
- UNUSED(t0);
+ // int64_t t0 = ggml_perf_time_us();
+ // UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS;
return;
}
+ // dst[:,:,:,:] = 0
+ // for i2,i3:
+ // for i1:
+ // for i01:
+ // for i0:
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
// parallelize by last three dimensions
// total rows in dst
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
- // dst[:,:,:,:] = 0
- // for i2,i3:
- // for i1:
- // for i01:
- // for i0:
- // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+ // block-tiling attempt
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+ const int64_t blck_1 = 16;
- for (int64_t ir = ir0; ir < ir1; ++ir) {
- // dst indices
- const int64_t i3 = ir/(ne2*ne1);
- const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
- const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+ for (int64_t ir = bir; ir < bir1; ++ir) {
+ // dst indices
+ const int64_t i3 = ir/(ne2*ne1);
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
- const int64_t i02 = i2;
- const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
- //const int64_t i10 = i1;
- const int64_t i12 = i2;
- const int64_t i13 = i3;
+ //const int64_t i10 = i1;
+ const int64_t i12 = i2;
+ const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+ }
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#else
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#endif
+ }
+ }
+ }
+
+
+ //int64_t t1 = ggml_perf_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ // int64_t t0 = ggml_perf_time_us();
+ // UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // we don't support permuted src0 dim0
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+
+ // dst dim0 cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ // GGML_ASSERT(nb0 <= nb1);
+ // GGML_ASSERT(nb1 <= nb2);
+ // GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+ if (params->type == GGML_TASK_INIT) {
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // parallelize by last three dimensions
+
+ // total rows in dst
+ const int64_t nr = ne1*ne2*ne3;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ // dst[:,:,:,:] = 0
+ // for i2,i3:
+ // for i1:
+ // for i01:
+ // for i0:
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ // dst indices
+ const int64_t i3 = ir/(ne2*ne1);
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ //const int64_t i10 = i1;
+ const int64_t i12 = i2;
+ const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01;
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
- ggml_vec_mad_f32(ne0, d, s0, *s1);
- // for (int64_t i0 = 0; i0 < ne0; ++i0) {
- // d[i0] += s0[i0] * s1[i1];
- // }
+ dequantize_row_q(s0, wdata, ne0);
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
}
}
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
{
- GGML_ASSERT(false); // todo
- // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+ ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
- GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_is_contiguous(dst));
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+ if (params->type == GGML_TASK_INIT) {
+ memset(dst->data, 0, ggml_nbytes(dst));
+ }
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
- GGML_ASSERT(ggml_are_same_shape(opt0, dst));
- GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_is_contiguous(dst));
// ggml_compute_forward_dup_same_cont(params, opt0, dst);
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
- ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+ ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
- ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+ ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
} break;
default:
{
S[i] = -INFINITY;
}
- for (int64_t ic = 0; ic < nek1; ++ic) {
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
}
// scale
- ggml_vec_scale_f32(nek1, S, scale);
+ ggml_vec_scale_f32(masked_begin, S, scale);
- if (masked) {
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = -INFINITY;
- }
- }
+ for (int64_t i = masked_begin; i < M; i++) {
+ S[i] = -INFINITY;
}
// softmax
+ // exclude known -INF S[..] values from max and loop
+ // dont forget to set their SW values to zero
{
float max = -INFINITY;
- ggml_vec_max_f32(M, &max, S);
+ ggml_vec_max_f32(masked_begin, &max, S);
ggml_float sum = 0.0;
{
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+ if (i >= masked_begin) {
+ break;
+ }
float * SS = S + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
- if (SS[j] == -INFINITY) {
+ if (i + j >= masked_begin) {
+ break;
+ } else if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
assert(sum > 0.0);
sum = 1.0/sum;
- ggml_vec_scale_f32(M, S, sum);
+ ggml_vec_scale_f32(masked_begin, S, sum);
#ifndef NDEBUG
- for (int i = 0; i < M; ++i) {
+ for (int i = 0; i < masked_begin; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f32(nek1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f32(masked_begin,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S);
}
}
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices
const int ik3 = iq3;
- const int ik2 = iq2;
+ const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
}
// softmax
+ // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
+ // dont forget to set their S values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
+ // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f16(nek1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f16(nev0,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
} else {
const int i2 = iq2;
const int i3 = iq3;
- ggml_vec_dot_f16_unroll(nek1, nbv1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ // v indices
+ const int iv2 = iq2 % nev2;
+ const int iv3 = iq3;
+
+ ggml_vec_dot_f16_unroll(nev0, nbv1,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
}
return;
}
- // parallelize by q rows using ggml_vec_dot_f32
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
- // total rows in q
- const int nr = neq2*neq3;
+ enum ggml_type result_type = dst->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+ void * grad_q = (char *) dst->data;
+ void * grad_k = (char *) dst->data + offs_k;
+ void * grad_v = (char *) dst->data + offs_v;
+
+ const size_t nbgq1 = nb0*neq0;
+ const size_t nbgq2 = nb0*neq0*neq1;
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+ const size_t nbgk1 = nb0*nek0;
+ const size_t nbgk2 = nb0*nek0*nek1;
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+ const size_t nbgv1 = nb0*nev0;
+ const size_t nbgv2 = nb0*nev0*nev1;
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+ // parallelize by k rows using ggml_vec_dot_f32
+
+ // total rows in k
+ const int nr = nek2*nek3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+ // how often k2 (and v2) is repeated in q2
+ int nrep = neq2/nek2;
+
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
- const int iq3 = ir/(neq2);
- const int iq2 = ir - iq3*neq2;
- for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+ const int ik3 = ir/(nek2);
+ const int ik2 = ir - ik3*nek2;
+ const int iq3 = ik3;
+ const int id3 = ik3;
+ const int iv3 = ik3;
+ const int iv2 = ik2;
- // not sure about CACHE_LINE_SIZE_F32..
- // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
- float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+ for (int irep = 0; irep < nrep; ++irep) {
+ const int iq2 = ik2 + irep*nek2;
+ const int id2 = iq2;
- for (int i = M; i < Mup; ++i) {
- S[i] = -INFINITY;
- }
+ // (ik2 + irep*nek2) % nek2 == ik2
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
+ const int id1 = iq1;
- for (int64_t ic = 0; ic < nek1; ++ic) {
- // k indices
- const int ik3 = iq3;
- const int ik2 = iq2;
- const int ik1 = ic;
+ // not sure about CACHE_LINE_SIZE_F32..
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
- // S indices
- const int i1 = ik1;
+ for (int i = M; i < Mup; ++i) {
+ S[i] = -INFINITY;
+ }
- ggml_vec_dot_f32(neq0,
- S + i1,
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
- }
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ // k indices
+ const int ik1 = ic;
- // scale
- ggml_vec_scale_f32(nek1, S, scale);
+ // S indices
+ const int i1 = ik1;
- if (masked) {
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = -INFINITY;
- }
+ ggml_vec_dot_f32(neq0,
+ S + i1,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
- }
- // softmax
- {
- float max = -INFINITY;
- ggml_vec_max_f32(M, &max, S);
+ // scale
+ ggml_vec_scale_f32(masked_begin, S, scale);
- ggml_float sum = 0.0;
+ for (int64_t i = masked_begin; i < M; i++) {
+ S[i] = -INFINITY;
+ }
+
+ // softmax
+ // exclude known -INF S[..] values from max and loop
+ // dont forget to set their SM values to zero
{
+ float max = -INFINITY;
+ ggml_vec_max_f32(masked_begin, &max, S);
+
+ ggml_float sum = 0.0;
+ {
#ifdef GGML_SOFT_MAX_ACCELERATE
- max = -max;
- vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
- vvexpf(SM, SM, &Mup);
- ggml_vec_sum_f32(Mup, &sum, SM);
+ max = -max;
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+ vvexpf(SM, SM, &Mup);
+ ggml_vec_sum_f32(Mup, &sum, SM);
#else
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
- float * SR = S + i;
- float * SW = SM + i;
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
- if (SR[j] == -INFINITY) {
- SW[j] = 0.0f;
- } else {
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+ if (i >= masked_begin) {
+ break;
+ }
+ float * SR = S + i;
+ float * SW = SM + i;
+
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+ if (i + j >= masked_begin) {
+ break;
+ } else if (SR[j] == -INFINITY) {
+ SW[j] = 0.0f;
+ } else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
- const float val = expf(SR[j] - max);
+ const float val = expf(SR[j] - max);
#else
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
- memcpy(&scvt[j], &s, sizeof(uint16_t));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
#endif
- sump[j] += (ggml_float)val;
- SW[j] = val;
+ sump[j] += (ggml_float)val;
+ SW[j] = val;
+ }
}
}
- }
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
- sum += sump[i];
- }
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+ sum += sump[i];
+ }
#endif
- }
-
- assert(sum > 0.0);
-
- sum = 1.0/sum;
- ggml_vec_scale_f32(M, SM, sum);
-
- }
-
- // step-by-step explanation
- {
- // forward-process shape grads from backward process
- // parallel_for iq2,iq3:
- // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
- // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
- // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
- // for iq1:
- // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
- // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
- // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
- // S0 = -Inf [D,1,1,1]
- // ~S1[i] = dot(kcur[:D,i], qcur)
- // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
- // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
- // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
- // ~S5[i] = dot(vcur[:,i], S4)
- // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
- // ~dst[i,iq1,iq2,iq3] = S5[i] ^
- // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
- // dst backward-/ grad[dst] = d
- //
- // output gradients with their dependencies:
- //
- // grad[kcur] = grad[S1].T @ qcur
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S4] = grad[S5] @ vcur
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
- // grad[qcur] = grad[S1] @ kcur
- // grad[vcur] = grad[S5].T @ S4
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
- //
- // in post-order:
- //
- // S1 = qcur @ kcur.T
- // S2 = S1 * scale
- // S3 = diag_mask_inf(S2, P)
- // S4 = softmax(S3)
- // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
- // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
- // grad[S1] = diag_mask_zero(grad[S3], P) * scale
- // grad[qcur] = grad[S1] @ kcur
- // grad[kcur] = grad[S1].T @ qcur
- // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
- //
- // using less variables (SM=S4):
- //
- // S = diag_mask_inf(qcur @ kcur.T * scale, P)
- // SM = softmax(S)
- // S = d[:D,iq1,iq2,iq3] @ vcur
- // dot_SM_gradSM = dot(SM, S)
- // S = SM * (S - dot(SM, S))
- // S = diag_mask_zero(S, P) * scale
- //
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
- }
-
- // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
- // S = d[:D,iq1,iq2,iq3] @ vcur
- // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
- ggml_vec_set_f32(M, S, 0);
- for (int64_t ic = 0; ic < D; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ }
- ggml_vec_mad_f32(M,
- S,
- (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
- }
+ assert(sum > 0.0);
- // S = SM * (S - dot(SM, S))
- float dot_SM_gradSM = 0;
- ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
- ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
- ggml_vec_mul_f32 (M, S, S, SM);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(masked_begin, SM, sum);
- // S = diag_mask_zero(S, P) * scale
- if (masked) {
- // for (int64_t i = P + iq1 + 1; i < M; i++) {
- // S[i] = 0;
- // }
- for (int64_t i = P; i < M; i++) {
- if (i > P + iq1) {
- S[i] = 0;
- }
}
- }
- ggml_vec_scale_f32(M, S, scale);
-
- void * grad_q = (char *) dst->data;
- void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
- void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
-
- const size_t nbgq1 = nb0*neq0;
- const size_t nbgq2 = nb0*neq0*neq1;
- const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
- const size_t nbgk1 = nb0*nek0;
- const size_t nbgk2 = nb0*nek0*nek1;
- const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
- const size_t nbgv1 = nb0*nev0;
- const size_t nbgv2 = nb0*nev0*nev1;
- const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
- // S shape [M,1]
- // SM shape [M,1]
- // kcur shape [D,M]
- // qcur shape [D,1]
- // vcur shape [M,D]
- //
- // grad[q][:D,iq1,iq2,iq3] += S @ kcur
- // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
- // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
- //
- //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
- //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
- for (int64_t ic = 0; ic < M; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
- (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
- S[ic]);
- }
+ // step-by-step explanation
+ {
+ // forward-process shape grads from backward process
+ // parallel_for ik2,ik3:
+ // for irep:
+ // iq2 = ik2 + irep*nek2
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
+ // for iq1:
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
+ // S0 = -Inf [D,1,1,1]
+ // ~S1[i] = dot(kcur[:D,i], qcur)
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
+ // ~S5[i] = dot(vcur[:,i], S4)
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+ // dst backward-/ grad[dst] = d
+ //
+ // output gradients with their dependencies:
+ //
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S4] = grad[S5] @ vcur
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[vcur] = grad[S5].T @ S4
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // in post-order:
+ //
+ // S1 = qcur @ kcur.T
+ // S2 = S1 * scale
+ // S3 = diag_mask_inf(S2, P)
+ // S4 = softmax(S3)
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // using less variables (SM=S4):
+ //
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
+ // SM = softmax(S)
+ // S = d[:D,iq1,iq2,iq3] @ vcur
+ // dot_SM_gradSM = dot(SM, S)
+ // S = SM * (S - dot(SM, S))
+ // S = diag_mask_zero(S, P) * scale
+ //
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ }
- // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
- // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
- // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
- for (int64_t ic = 0; ic < M; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // for ic:
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+ // exclude known future zero S[..] values from operation
+ ggml_vec_set_f32(masked_begin, S, 0);
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ S,
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
- // ggml_vec_set_f32(D,
- // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
- // 0);
- ggml_vec_mad_f32(D,
- (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
- (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
- S[ic]);
- }
+ // S = SM * (S - dot(SM, S))
+ float dot_SM_gradSM = 0;
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+ // S = diag_mask_zero(S, P) * scale
+ // already done by above ggml_vec_set_f32
+
+ // exclude known zero S[..] values from operation
+ ggml_vec_scale_f32(masked_begin, S, scale);
+
+ // S shape [M,1]
+ // SM shape [M,1]
+ // kcur shape [D,M]
+ // qcur shape [D,1]
+ // vcur shape [M,D]
+
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+ // for ic:
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ S[ic]);
+ }
- // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
- // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
- // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
- for (int64_t ic = 0; ic < D; ++ic) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+ // for ic:
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
+ S[ic]);
+ }
- // ggml_vec_set_f32(M,
- // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
- // 0);
- ggml_vec_mad_f32(M,
- (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
- SM,
- *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ // for ic:
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
+ // exclude known zero SM[..] values from mad
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+ SM,
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
}
}
}
} break;
case GGML_OP_GET_ROWS_BACK:
{
- ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_DIAG:
{
////////////////////////////////////////////////////////////////////////////////
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static size_t hash_find(void * hash_table[], void * p) {
+ size_t h = hash(p);
+
+ // linear probing
+ size_t i = h;
+ while (hash_table[i] != NULL && hash_table[i] != p) {
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+ if (i == h) {
+ // visited all hash table entries -> not found
+ return GGML_GRAPH_HASHTABLE_SIZE;
+ }
+ }
+ return i;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+ size_t i = hash_find(hash_table, p);
+
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+
+ if (hash_table[i] == p) {
+ return true;
+ }
+
+ // insert
+ GGML_ASSERT(hash_table[i] == NULL);
+ hash_table[i] = p;
+ return false;
+}
+
+static bool hash_contains(void * hash_table[], void * p) {
+ size_t i = hash_find(hash_table, p);
+ return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
+
+struct hash_map {
+ void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+ void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+
+static struct hash_map * new_hash_map(void) {
+ struct hash_map * result = malloc(sizeof(struct hash_map));
+ for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+ result->keys[i] = NULL;
+ result->vals[i] = NULL;
+ }
+ return result;
+}
+
+static void free_hash_map(struct hash_map * map) {
+ free(map);
+}
+
+// gradient checkpointing
+
+static struct ggml_tensor * ggml_recompute_graph_node(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ struct hash_map * replacements,
+ struct ggml_tensor * node) {
+
+ if (node == NULL) {
+ return NULL;
+ }
+
+ if (node->is_param) {
+ return node;
+ }
+
+ if (!hash_contains(graph->visited_hash_table, node)) {
+ return node;
+ }
+
+ int count_children = 0;
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ if (node->src[k]) {
+ ++count_children;
+ }
+ }
+
+ if (count_children == 0) {
+ return node;
+ }
+
+ size_t i = hash_find(replacements->keys, node);
+ GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ if (replacements->keys[i] == node) {
+ return (struct ggml_tensor *) replacements->vals[i];
+ }
+
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+
+ // insert clone into replacements
+ GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+ replacements->keys[i] = node;
+ replacements->vals[i] = clone;
+
+ clone->op = node->op;
+ clone->grad = node->grad;
+ clone->is_param = node->is_param;
+ clone->extra = node->extra;
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+ clone->nb[k] = node->nb[k];
+ }
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+ }
+ if (node->view_src != NULL) {
+ clone->data = (node->view_src->data == NULL)
+ ? NULL // view_src not yet allocated
+ : (char *) node->view_src->data // view_src already allocated
+ + node->view_offs;
+ clone->view_src = node->view_src;
+ clone->view_offs = node->view_offs;
+ }
+
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
+
+ return clone;
+}
+
+void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints) {
+ *gb_tmp = *gf;
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+ if (n_checkpoints <= 0) {
+ *gb = *gb_tmp;
+ return;
+ }
+
+ struct hash_map * replacements = new_hash_map();
+
+ // insert checkpoints in replacements
+ for (int i = 0; i < n_checkpoints; ++i) {
+ size_t k = hash_find(replacements->keys, checkpoints[i]);
+ GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+ GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+ replacements->keys[k] = checkpoints[i];
+ replacements->vals[k] = checkpoints[i];
+ }
+
+ *gb = *gf;
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+ // by recomputing them from checkpoints
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
+ struct ggml_tensor * node = gb_tmp->nodes[i];
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ // insert new tensors recomputing src, reusing already made replacements,
+ // remember replacements: remember new tensors with mapping from corresponding gf nodes
+ // recurse for input tensors,
+ // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+ }
+ // insert rewritten backward node with replacements made into resulting backward graph gb
+ ggml_build_forward_expand(gb, node);
+ }
+
+ free_hash_map(replacements);
+}
+
+// functions to change gradients considering the case that input a might be initial gradient with zero value
+
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return b;
+ } else {
+ return ggml_add_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+ } else {
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+ }
+}
+
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return ggml_repeat(ctx, b, a);
+ } else {
+ return ggml_add1_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+ if (hash_contains(zero_table, a)) {
+ return ggml_neg(ctx, b);
+ } else {
+ return ggml_sub_impl(ctx, a, b, false);
+ }
+}
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
case GGML_OP_DUP:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_ADD:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_ADD1:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_add_impl(ctx,
+ src1->grad = ggml_add_or_set(ctx,
src1->grad,
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ACC:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
nb1, nb2, nb3, offset);
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUB:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
- src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_MUL:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, src1, tensor->grad),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_mul(ctx, src0, tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_DIV:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_div(ctx, tensor->grad, src1),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_sub_impl(ctx,
+ ggml_sub_or_set(ctx,
src1->grad,
ggml_mul(ctx,
tensor->grad,
ggml_div(ctx, tensor, src1)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SQR:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
ggml_new_f32(ctx, 2.0f)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SQRT:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale(ctx,
ggml_div(ctx,
tensor->grad,
tensor),
ggml_new_f32(ctx, 0.5f)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_LOG:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_div(ctx,
tensor->grad,
src0),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUM:
{
if (src0->grad) {
src0->grad =
- ggml_add1_impl(ctx,
+ ggml_add1_or_set(ctx,
src0->grad,
tensor->grad,
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SUM_ROWS:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_repeat(ctx,
tensor->grad,
src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_MEAN:
{
// necessary for llama
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat_back(ctx, tensor->grad, src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_REPEAT_BACK:
{
if (src0->grad) {
// TODO: test this
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_repeat(ctx, tensor->grad, src0->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CONCAT:
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_RMS_NORM_BACK:
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
// ds1 = t.T.dot(dt)
- // tensor.shape [m,p]
- // src0.shape [n,m]
- // src1.shape [n,p]
+ // tensor.shape [m,p,qq,rr]
+ // src0.shape [n,m,q1,r1]
+ // src1.shape [n,p,qq,rr]
// necessary for llama
if (src0->grad) {
+ struct ggml_tensor * s1_tg =
+ ggml_out_prod(ctx, // [n,m,qq,rr]
+ src1, // [n,p,qq,rr]
+ tensor->grad); // [m,p,qq,rr]
+ const int64_t qq = s1_tg->ne[2];
+ const int64_t rr = s1_tg->ne[3];
+ const int64_t q1 = src0->ne[2];
+ const int64_t r1 = src0->ne[3];
+ const bool ne2_broadcasted = qq > q1;
+ const bool ne3_broadcasted = rr > r1;
+ if (ne2_broadcasted || ne3_broadcasted) {
+ // sum broadcast repetitions of s1_tg into shape of src0
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+ }
src0->grad =
- ggml_add_impl(ctx,
- src0->grad,
- ggml_out_prod(ctx, // [n,m]
- src1, // [n,p]
- tensor->grad), // [m,p]
- inplace);
+ ggml_add_or_set(ctx,
+ src0->grad, // [n,m,q1,r1]
+ s1_tg, // [n,m,q1,r1]
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
- src1->grad,
- // ggml_mul_mat(ctx, // [n,p]
- // ggml_cont(ctx, // [m,n]
- // ggml_transpose(ctx, src0)), // [m,n]
- // tensor->grad), // [m,p]
+ ggml_add_or_set(ctx,
+ src1->grad, // [n,p,qq,rr]
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
+ // ggml_cont(ctx, // [m,n,q1,r1]
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+ // tensor->grad), // [m,p,qq,rr]
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
// // avoid transpose of src0, rather transpose smaller tensor->grad
// // and then use ggml_out_prod
- ggml_out_prod(ctx, // [n,p]
- src0, // [n,m]
- ggml_transpose(ctx, // [p,m]
- tensor->grad)), // [m,p]
- inplace);
+ ggml_out_prod(ctx, // [n,p,qq,rr]
+ src0, // [n,m,q1,r1]
+ ggml_transpose(ctx, // [p,m,qq,rr]
+ tensor->grad)), // [m,p,qq,rr]
+ zero_table);
}
} break;
case GGML_OP_OUT_PROD:
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_scale_impl(ctx, tensor->grad, src1, false),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SET:
}
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_acc_impl(ctx,
tensor->grad,
ggml_neg(ctx, tensor_grad_view),
nb1, nb2, nb3, offset, false),
- inplace);
+ zero_table);
}
if (src1->grad) {
src1->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src1->grad,
ggml_reshape(ctx,
ggml_cont(ctx, tensor_grad_view),
src1->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CPY:
// tensor = src0 * 1 + src1 * 0
if (src0->grad) {
// dsrc0 = dtensor * 1
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
// dsrc1 = dtensor * 0 -> noop
if (src0->grad) {
GGML_ASSERT(ggml_is_contiguous(src0->grad));
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
- src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_RESHAPE:
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
- ggml_reshape(ctx, tensor->grad, src0->grad),
- inplace);
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_reshape(ctx,
+ ggml_is_contiguous(tensor->grad)
+ ? tensor->grad
+ : ggml_cont(ctx, tensor->grad),
+ src0->grad),
+ zero_table);
}
} break;
case GGML_OP_VIEW:
nb3 = (nb3 / n0) * ng;
}
- src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
}
} break;
case GGML_OP_PERMUTE:
axes_backward[axis2] = 2;
axes_backward[axis3] = 3;
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_permute(ctx,
tensor->grad,
axes_backward[0],
axes_backward[1],
axes_backward[2],
axes_backward[3]),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_TRANSPOSE:
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_transpose(ctx, tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_GET_ROWS:
// necessary for llama (only for tokenizer)
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
+ // last ggml_get_rows_back argument src0->grad is only
+ // necessary to setup correct output shape
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
- inplace);
+ zero_table);
}
if (src1->grad) {
// noop
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_DIAG_MASK_ZERO:
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_SOFT_MAX:
// necessary for llama
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx, src0->grad,
+ ggml_add_or_set(ctx, src0->grad,
ggml_soft_max_back(ctx, tensor->grad, tensor),
- inplace);
+ zero_table);
}
} break;
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
freq_scale,
xpos_base,
xpos_down),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ROPE_BACK:
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_impl(ctx,
tensor->grad,
xpos_base,
xpos_down,
false),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_ALIBI:
masked);
}
- if (src0->grad) {
- struct ggml_tensor * grad_q = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = 0;
- switch(src0->n_dims) {
- case 2:
- {
- grad_q = ggml_view_2d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- nb0*src0->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_q = ggml_view_3d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- src0->ne[2],
- nb0*src0->ne[0],
- nb0*src0->ne[0]*src0->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_q = ggml_view_4d(ctx,
- flash_grad,
- src0->ne[0],
- src0->ne[1],
- src0->ne[2],
- src0->ne[3],
- nb0*src0->ne[0],
- nb0*src0->ne[0]*src0->ne[1],
- nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
- offset);
- } break;
- }
+ struct ggml_tensor * src2 = tensor->src[2];
+ const int64_t elem_q = ggml_nelements(src0);
+ const int64_t elem_k = ggml_nelements(src1);
+ const int64_t elem_v = ggml_nelements(src2);
+
+ enum ggml_type result_type = flash_grad->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
- src0->grad = ggml_add_impl(ctx,
+ if (src0->grad) {
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
grad_q,
- inplace);
+ zero_table);
}
-
if (src1->grad) {
- struct ggml_tensor * grad_k = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
- switch(src1->n_dims) {
- case 2:
- {
- grad_k = ggml_view_2d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- nb0*src1->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_k = ggml_view_3d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- src1->ne[2],
- nb0*src1->ne[0],
- nb0*src1->ne[0]*src1->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_k = ggml_view_4d(ctx,
- flash_grad,
- src1->ne[0],
- src1->ne[1],
- src1->ne[2],
- src1->ne[3],
- nb0*src1->ne[0],
- nb0*src1->ne[0]*src1->ne[1],
- nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
- offset);
- } break;
- }
-
- src1->grad = ggml_add_impl(ctx,
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
+ src1->grad = ggml_add_or_set(ctx,
src1->grad,
grad_k,
- inplace);
+ zero_table);
}
-
- struct ggml_tensor * opt0 = tensor->src[2];
-
- if (opt0->grad) {
- struct ggml_tensor * grad_v = NULL;
- const size_t nb0 = flash_grad->nb[0];
- const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
- + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
- switch(opt0->n_dims) {
- case 2:
- {
- grad_v = ggml_view_2d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- nb0*opt0->ne[0],
- offset);
- } break;
- case 3:
- {
- grad_v = ggml_view_3d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- opt0->ne[2],
- nb0*opt0->ne[0],
- nb0*opt0->ne[0]*opt0->ne[1],
- offset);
- } break;
- case 4:
- {
- grad_v = ggml_view_4d(ctx,
- flash_grad,
- opt0->ne[0],
- opt0->ne[1],
- opt0->ne[2],
- opt0->ne[3],
- nb0*opt0->ne[0],
- nb0*opt0->ne[0]*opt0->ne[1],
- nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
- offset);
- } break;
- }
-
- opt0->grad = ggml_add_impl(ctx,
- opt0->grad,
+ if (src2->grad) {
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
+ src2->grad = ggml_add_or_set(ctx,
+ src2->grad,
grad_v,
- inplace);
+ zero_table);
}
} break;
case GGML_OP_FLASH_FF:
{
if (src0->grad) {
src0->grad =
- ggml_add_impl(ctx,
+ ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
ggml_sgn(ctx, src0),
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
{
if (src0->grad) {
- src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_RELU:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
ggml_step(ctx, src0),
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_UNARY_OP_GELU:
{
// necessary for llama
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_silu_back(ctx, src0, tensor->grad),
- inplace);
+ zero_table);
}
} break;
default:
case GGML_OP_CROSS_ENTROPY_LOSS:
{
if (src0->grad) {
- src0->grad = ggml_add_impl(ctx,
+ src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_cross_entropy_loss_back(ctx,
src0,
src1,
tensor->grad),
- inplace);
+ zero_table);
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
GGML_ASSERT(false);
} break;
}
-}
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
-
-static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
- size_t h = hash(p);
-
- // linear probing
- size_t i = h;
- while (hash_table[i] != NULL && hash_table[i] != p) {
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // hash table is full
- GGML_ASSERT(false);
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (tensor->src[i] && tensor->src[i]->grad) {
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
}
}
-
- if (hash_table[i] == p) {
- return true;
- }
-
- // insert
- hash_table[i] = p;
- return false;
}
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
}
for (int i = 0; i < GGML_MAX_SRC; ++i) {
- if (node->src[i]) {
- ggml_visit_parents(cgraph, node->src[i]);
+ const int k =
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+ /* unknown order, just fall back to using i*/ i;
+ if (node->src[k]) {
+ ggml_visit_parents(cgraph, node->src[k]);
}
}
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
/*.hash_table =*/ { NULL },
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
}
}
+ // remember original gradients which start with zero values
+ void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
+ memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+ for (int i = 0; i < gf->n_nodes; i++) {
+ if (gf->grads[i]) {
+ hash_insert(zero_table, gf->grads[i]);
+ }
+ }
+
for (int i = gf->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = gf->nodes[i];
- // because we detached the grad nodes from the original graph, we can afford inplace operations
+ // inplace operations to add gradients are not created by ggml_compute_backward
+ // use allocator to automatically make inplace operations
if (node->grad) {
- ggml_compute_backward(ctx, node, keep);
+ ggml_compute_backward(ctx, node, zero_table);
}
}
ggml_build_forward_expand(gb, node->grad);
}
}
+
+ free(zero_table);
}
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
/*.hash_table =*/ { NULL },
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
} break;
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
- case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
cur = 0;
}
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_OUT_PROD:
+ {
+ n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+
work_size = MAX(work_size, cur);
} break;
case GGML_OP_SCALE:
}
static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
- int i = 0;
+ int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_nelements(ps[p]) ;
// TODO: add function to get all elements at once
}
}
+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+ int64_t i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int64_t ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int64_t j = 0; j < ne; ++j) {
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+ }
+ }
+}
+
//
// ADAM
//
const float eps = params.adam.eps;
const float gclip = params.adam.gclip;
const int decay_min_ndim = params.adam.decay_min_ndim;
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+ float * g = opt->adam.g->data; // gradients
float * m = opt->adam.m->data; // first moment
float * v = opt->adam.v->data; // second moment
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
- if (callback) {
- callback(callback_data, &sched);
- }
-
- // compute the function value
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
- ggml_graph_compute(gb, &cplan);
- opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+ bool cancel = false;
+
+ // compute the function value
+ float fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ if (cancel) {
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+ fx *= accum_norm;
+
+ opt->adam.fx_prev = fx;
opt->adam.fx_best = opt->adam.fx_prev;
if (pf) {
pf[opt->iter % params.past] = opt->adam.fx_prev;
// run the optimizer
for (int t = 0; t < params.adam.n_iter; ++t) {
+ if (cancel) {
+ break;
+ }
opt->iter = iter0 + t + 1;
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
if (gclip > 0.0f) {
// gradient clipping
ggml_float sum = 0.0;
- for (int p = 0; p < np; ++p) {
- const int64_t ne = ggml_nelements(ps[p]);
- for (int64_t j = 0; j < ne; ++j) {
- float g = ggml_get_f32_1d(ps[p]->grad, j);
- sum += (ggml_float)(g*g);
- }
+ for (int64_t i = 0; i < nx; ++i) {
+ sum += (ggml_float)(g[i]*g[i]);
}
ggml_float norm = sqrt(sum);
if (norm > (ggml_float) gclip) {
const int64_t ne = ggml_nelements(ps[p]);
const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
for (int64_t j = 0; j < ne; ++j) {
- float x = ggml_get_f32_1d(ps[p], j);
- float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
- m[i] = m[i]*beta1 + g*(1.0f - beta1);
- v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+ float x = ggml_get_f32_1d(ps[p], j);
+ float g_ = g[i]*gnorm;
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
float mh = m[i]*beta1h;
float vh = v[i]*beta2h;
vh = sqrtf(vh) + eps;
}
}
- if (callback) {
- callback(callback_data, &sched);
+ fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
}
+ if (cancel) {
+ break;
+ }
+ fx *= accum_norm;
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, &cplan);
-
- const float fx = ggml_get_f32_1d(f, 0);
opt->loss_after = fx;
float * step,
const float * xp,
struct ggml_tensor * f,
- struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
struct ggml_cplan * cplan,
const int np,
struct ggml_tensor * ps[],
+ bool * cancel,
ggml_opt_callback callback,
void * callback_data) {
int count = 0;
const float dec = 0.5f;
const float inc = 2.1f;
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
if (*step <= 0.f) {
return GGML_LINESEARCH_INVALID_PARAMETERS;
}
finit = *fx;
dgtest = params->lbfgs.ftol*dginit;
- while (true) {
- if (callback) {
- // LBFG-S does not support learning rate -> ignore learning schedule
- float sched = 0;
- callback(callback_data, &sched);
- }
-
+ while (!*cancel) {
ggml_vec_cpy_f32(nx, x, xp);
ggml_vec_mad_f32(nx, x, d, *step);
{
ggml_opt_set_params(np, ps, x);
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, cplan);
-
- ggml_opt_get_grad(np, ps, g);
+ *fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ // LBFG-S does not support learning rate -> ignore learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, cancel);
+ if (*cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ *fx += ggml_get_f32_1d(f, 0);
+ }
+ if (*cancel) {
+ break;
+ }
+ *fx *= accum_norm;
- *fx = ggml_get_f32_1d(f, 0);
}
++count;
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
float fx = 0.0f; // cost function value
float xnorm = 0.0f; // ||x||
float gnorm = 0.0f; // ||g||
float * lm_s = opt->lbfgs.lms->data;
float * lm_y = opt->lbfgs.lmy->data;
- if (callback) {
- // LBFG-S does not support learning rate -> ignore learning schedule
- float sched = 0;
- callback(callback_data, &sched);
- }
+ bool cancel = false;
// evaluate the function value and its gradient
{
ggml_opt_set_params(np, ps, x);
- ggml_graph_reset (gf);
- ggml_set_f32 (f->grad, 1.0f);
-
- ggml_graph_compute(gb, &cplan);
-
- ggml_opt_get_grad(np, ps, g);
-
- fx = ggml_get_f32_1d(f, 0);
+ fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ // LBFG-S does not support learning rate -> ignore learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ break;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ if (cancel) {
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+ fx *= accum_norm;
opt->loss_before = fx;
opt->loss_after = fx;
ggml_vec_cpy_f32(nx, xp, x);
ggml_vec_cpy_f32(nx, gp, g);
- ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
+ ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+ if (!cancel) {
+ break;
+ }
if (ls < 0) {
// linesearch failed - go back to the previous point and return
.print_forward_graph = true,
.print_backward_graph = true,
+ .n_gradient_accumulation = 1,
+
.adam = {
.n_iter = 10000,
.sched = 1.000f,
.print_forward_graph = true,
.print_backward_graph = true,
+ .n_gradient_accumulation = 1,
+
.lbfgs = {
.m = 6,
.n_iter = 100,
opt->iter = 0;
opt->nx = nx;
opt->just_initialized = true;
+ if (opt->ctx == NULL) {
+ struct ggml_init_params ctx_opt_params;
+ if (opt->params.type == GGML_OPT_ADAM) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ } else if (opt->params.type == GGML_OPT_LBFGS) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ }
+ ctx_opt_params.mem_buffer = NULL;
+ ctx_opt_params.no_alloc = false;
+
+ opt->ctx = ggml_init(ctx_opt_params);
+ }
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
- opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.pf = params.past > 0
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
ggml_set_zero(opt->adam.m);
ggml_set_zero(opt->adam.v);
} break;
case GGML_OPT_LBFGS:
{
- opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
- opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.pf = params.past > 0
- ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
- opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
- opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
- opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
- opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
ggml_set_zero(opt->lbfgs.x);
ggml_set_zero(opt->lbfgs.xp);
ggml_set_zero(opt->lbfgs.g);
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4
-#define GGML_MAX_NODES 4096
-#define GGML_MAX_PARAMS 256
+#define GGML_MAX_NODES 16384
+#define GGML_MAX_PARAMS 1024
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 64
// next prime after GGML_MAX_NODES
// #define GGML_GRAPH_HASHTABLE_SIZE 4099
// next prime after GGML_MAX_NODES * 2 (nodes + leafs)
- #define GGML_GRAPH_HASHTABLE_SIZE 8273
+ // #define GGML_GRAPH_HASHTABLE_SIZE 8273
+ // #define GGML_GRAPH_HASHTABLE_SIZE 16411
+ #define GGML_GRAPH_HASHTABLE_SIZE 32771
+
+ enum ggml_cgraph_eval_order {
+ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+ GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+ GGML_CGRAPH_EVAL_ORDER_COUNT
+ };
// computation graph
struct ggml_cgraph {
void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+ enum ggml_cgraph_eval_order order;
+
// performance
int perf_runs;
int64_t perf_cycles;
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+ // Converts a flat index into coordinates
+ GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+ GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+ GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type);
+
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * a,
struct ggml_tensor * b);
+ // sums repetitions in a into shape of b
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
// dump the graph into a file using the dot format
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+ // build gradient checkpointing backward graph gb for gf using provided checkpoints
+ // gb_tmp will contain original backward graph with rewritten backward process nodes,
+ // but without the second forward pass nodes.
+ GGML_API void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints);
//
// optimization
//
GGML_LINESEARCH_INVALID_PARAMETERS,
};
- typedef void (*ggml_opt_callback)(void * data, float * sched);
+ typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
// optimization parameters
bool print_forward_graph;
bool print_backward_graph;
+ int n_gradient_accumulation;
+
// ADAM parameters
struct {
int n_iter;
float loss_after;
struct {
+ struct ggml_tensor * g; // current gradient
struct ggml_tensor * m; // first moment
struct ggml_tensor * v; // second moment
struct ggml_tensor * pf; // past function values
// TODO: after the GGUF PR, this likely won't work and needs to be updated
static int llama_apply_lora_from_file_internal(
- const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
) {
LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
int32_t lora_alpha;
fin.read((char *) &lora_r, sizeof(lora_r));
fin.read((char *) &lora_alpha, sizeof(lora_alpha));
- float scaling = (float)lora_alpha / (float)lora_r;
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
ggml_set_name(r, "r_cpy");
}
- struct ggml_cgraph gf = ggml_build_forward(r);
+ struct ggml_cgraph * gf = ggml_new_graph(lora_ctx);
+ ggml_build_forward_expand(gf, r);
- ggml_graph_compute_helper(work_buffer, &gf, n_threads);
+ ggml_graph_compute_helper(work_buffer, gf, n_threads);
// we won't need these tensors again, reset the context to save memory
ggml_free(lora_ctx);
return nparams;
}
+struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
+ return ggml_get_tensor(model->ctx, name);
+}
+
int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
}
}
-int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
try {
- return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads);
+ return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
return 1;
}
}
-int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
try {
- return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads);
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
return 1;
// Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
+ // Get a llama model tensor
+ LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
+
// Returns 0 on success
LLAMA_API int llama_model_quantize(
const char * fname_inp,
LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
struct llama_context * ctx,
const char * path_lora,
+ float scale,
const char * path_base_model,
int n_threads),
"use llama_model_apply_lora_from_file instead");
LLAMA_API int llama_model_apply_lora_from_file(
const struct llama_model * model,
- const char * path_lora,
- const char * path_base_model,
- int n_threads);
+ const char * path_lora,
+ float scale,
+ const char * path_base_model,
+ int n_threads);
//
// KV cache
printf("GGML_N_THREADS = %d\n", n_threads);
}
- struct ggml_cgraph gf = ggml_build_forward (f);
- struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+ struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
+ struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+ *gb = *gf;
+ ggml_build_backward_expand(ctx0, gf, gb, false);
- ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+ ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
- ggml_graph_reset (&gf);
+ ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+ ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
- // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
- // ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
+ // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
+ // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot");
for (int i = 0; i < nargs; ++i) {
const int nelements = ggml_nelements(x[i]);
const float xp = x0 + eps;
ggml_set_f32_1d(x[i], k, xp);
- ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+ ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
const double f0 = ggml_get_f32_1d(f, 0);
ggml_set_f32_1d(x[i], k, xm);
- ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+ ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
const double f1 = ggml_get_f32_1d(f, 0);
const double g0 = (f0 - f1)/(2.0*(double) eps);
ggml_set_f32_1d(x[i], k, x0);
// compute gradient using backward graph
- ggml_graph_reset (&gf);
+ ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+ ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
const double g1 = ggml_get_f32_1d(x[i]->grad, k);
int main(int argc, const char ** argv) {
struct ggml_init_params params = {
- /* .mem_size = */ 128*1024*1024,
+ /* .mem_size = */ 256*1024*1024,
/* .mem_buffer = */ NULL,
/* .no_alloc = */ false,
};
}
}
+ unsigned seed_iter = 1;
// original loop: 1000
int niter = 4;
niter = atoi(argv[1]);
}
for (int iter = 0; iter < niter; ++iter) {
+ srand(seed_iter);
+ seed_iter = rand();
+ unsigned seed = rand();
+
printf("test-grad0: iter:%d/%d\n", iter, niter);
struct ggml_context * ctx0 = ggml_init(params);
// add f32
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
// add f16
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
// sub
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
// mul
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
// div
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
// sqr
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// sqrt
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// log
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// sum
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// sum_rows
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// mean, not yet fully implemented
if(0)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// argmax
if (0)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// repeat
{
+ srand(seed);
int64_t ne2[4];
get_random_dims(ne2, 4);
// repeat back
{
+ srand(seed);
int64_t ne2[4];
get_random_dims(ne2, 4);
// sgn
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// neg
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// step
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// tanh, not yet fully implemented
if(0)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// mul_mat
{
+ srand(seed);
const int nargs = 2;
- for (int ndims = 2; ndims <= 2; ++ndims) {
+ for (int ndims = 2; ndims <= 4; ++ndims) {
+ int max_nrep = (ndims >= 3) ? 2 : 1;
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
- {
- int64_t ne2[4];
- get_random_dims(ne2, 4);
- ne2[0] = ne[0];
- x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
- }
+ for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
+ for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
+ {
+ int64_t ne2[4];
+ get_random_dims(ne2, 4);
+ ne2[0] = ne[0];
+ ne2[2] = nrep2 * ne[2];
+ ne2[3] = nrep3 * ne[3];
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+ }
- ggml_set_param(ctx0, x[0]);
- ggml_set_param(ctx0, x[1]);
+ ggml_set_param(ctx0, x[0]);
+ ggml_set_param(ctx0, x[1]);
- struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
- struct ggml_tensor * f = ggml_sum(ctx0, m);
+ struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+ struct ggml_tensor * f = ggml_sum(ctx0, m);
- GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
+ GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
- check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
- check_mat_mul(m, x[1], x[0]);
+ check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+ if (ndims == 2) {
+ // check_mat_mul does not support ndims > 2
+ check_mat_mul(m, x[1], x[0]);
+ }
+ }
+ }
}
}
// elu, not yet fully implemented
if(0)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// relu
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// gelu, not yet fully implemented
if(0)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// silu
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// rms_norm
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// scale
{
+ srand(seed);
const int nargs = 2;
int64_t ne2[4];
// cpy f32
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 2; ++ndims) {
// cpy f16
{
+ srand(seed);
const int nargs = 2;
for (int ndims = 1; ndims <= 2; ++ndims) {
// reshape (1d->nd)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// reshape (nd->1d)
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
// acc 1d
{
+ srand(seed);
int64_t ne2[4] = { 1, 1, 1, 1 };
const int nargs = 2;
// acc 2d
{
+ srand(seed);
int64_t ne2[4] = { 1, 1, 1, 1 };
int64_t max_offsets[4] = { 0, 0, 0, 0 };
int64_t offsets[4] = { 0, 0, 0, 0 };
// acc 3d
{
+ srand(seed);
int64_t ne2[4] = { 1, 1, 1, 1 };
int64_t max_offsets[4] = { 0, 0, 0, 0 };
int64_t offsets[4] = { 0, 0, 0, 0 };
// acc 4d
{
+ srand(seed);
int64_t ne2[4] = { 1, 1, 1, 1 };
int64_t max_offsets[4] = { 0, 0, 0, 0 };
int64_t offsets[4] = { 0, 0, 0, 0 };
// set_1d
{
+ srand(seed);
int64_t ne2[4];
const int nargs = 2;
// set_2d
{
+ srand(seed);
int64_t ne2[4];
int64_t max_offsets[4] = { 0, 0, 0, 0 };
int64_t offsets[4] = { 0, 0, 0, 0 };
// view_1d
{
+ srand(seed);
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
// view_2d
{
+ srand(seed);
int64_t ne2[4];
int64_t nb2[4];
// view_3d
{
+ srand(seed);
int64_t ne2[4] = {1,1,1,1};
int64_t nb2[4] = {0,0,0,0};
// permute
{
+ srand(seed);
int64_t ne2[4];
const int nargs = 1;
// transpose
{
+ srand(seed);
int64_t ne2[4];
const int nargs = 1;
// get_rows
{
+ srand(seed);
int64_t ne2[4] = {ne[0], ne[1], 1, 1};
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
const int nargs = 1;
// diag_mask_inf
{
+ srand(seed);
const int nargs = 1;
const int ndims = 2;
// diag_mask_zero
{
+ srand(seed);
const int nargs = 1;
const int ndims = 2;
// softmax
{
+ srand(seed);
const int nargs = 1;
int64_t ne2[4];
ggml_new_f32(ctx0, eps))));
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
+ // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
+ // this may result in different gradients too finite differences.
+ // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
+ // if only the table lookup causes gradients to differ this is acceptable.
}
}
// cross_entropy_loss
{
+ srand(seed);
const int nargs = 1;
int64_t ne2[4];
// rope f32
{
+ srand(seed);
const int nargs = 1;
int64_t ne2[4];
// rope f16
{
+ srand(seed);
const int nargs = 1;
int64_t ne2[4];
// flash_attn f32
{
+ srand(seed);
const int nargs = 3;
int64_t ne2[4];
for (int masked = 0; masked <= 1; ++masked) {
for (int ndims = 2; ndims <= 4; ++ndims) {
- int64_t neq[4] = { D, N, B, ne[3] };
- int64_t nek[4] = { D, M, B, ne[3] };
- int64_t nev[4] = { M, D, B, ne[3] };
- if (ndims == 2) {
- neq[2] = 1; neq[3] = 1;
- nek[2] = 1; nek[3] = 1;
- nev[2] = 1; nev[3] = 1;
- } else if (ndims == 3) {
- neq[3] = 1;
- nek[3] = 1;
- nev[3] = 1;
- }
- x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
- x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
- x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
- ggml_set_param(ctx0, x[0]);
- ggml_set_param(ctx0, x[1]);
- ggml_set_param(ctx0, x[2]);
+ int max_nrep = (ndims >= 3) ? 2 : 1;
+ for (int nrep = 1; nrep < max_nrep; ++nrep) {
+ int64_t neq[4] = { D, N, B*nrep, ne[3] };
+ int64_t nek[4] = { D, M, B, ne[3] };
+ int64_t nev[4] = { M, D, B, ne[3] };
+ if (ndims == 2) {
+ neq[2] = 1; neq[3] = 1;
+ nek[2] = 1; nek[3] = 1;
+ nev[2] = 1; nev[3] = 1;
+ } else if (ndims == 3) {
+ neq[3] = 1;
+ nek[3] = 1;
+ nev[3] = 1;
+ }
+ x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+ x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+ ggml_set_param(ctx0, x[0]);
+ ggml_set_param(ctx0, x[1]);
+ ggml_set_param(ctx0, x[2]);
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
- check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+ check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+ }
}
}
}
// flash_attn f16, not yet fully implemented
if(0)
{
+ srand(seed);
const int nargs = 3;
int64_t ne2[4];