]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert-llama2c-to-ggml : enable conversion of GQA models (#6237)
authorfraxy-v <redacted>
Fri, 22 Mar 2024 18:49:06 +0000 (20:49 +0200)
committerGitHub <redacted>
Fri, 22 Mar 2024 18:49:06 +0000 (20:49 +0200)
* convert-llama2c-to-ggml: enable conversion of multiqueries, #5608

* add test in build action

* Update build.yml

* Update build.yml

* Update build.yml

* gg patch

.github/workflows/build.yml
examples/convert-llama2c-to-ggml/README.md
examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

index 7711bd8d8c470ffef2aa025385b336035f324369..bf42df8fe58d42589e940998bec7703f0c9737d7 100644 (file)
@@ -225,6 +225,17 @@ jobs:
           cd build
           ctest -L main --verbose --timeout 900
 
+      - name: Test llama2c conversion
+        id: llama2c_test
+        run: |
+          cd build
+          echo "Fetch tokenizer"
+          wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin
+          echo "Fetch llama2c model"
+          wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin
+          ./bin/convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf
+          ./bin/main -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256
+
 #  ubuntu-latest-cmake-sanitizer:
 #    runs-on: ubuntu-latest
 #
index 0f37d295bd9ee440fe7135f725d260084c3a34b0..6da2d7e1809a94093f91587f38cd53433dca73f9 100644 (file)
@@ -21,6 +21,8 @@ An example command using a model from [karpathy/tinyllamas](https://huggingface.
 
 `$ ./convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin`
 
+Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K).
+
 Now you can use the model with a command like:
 
 `$ ./main -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256`
index 8209dcb642bee9afcfd2d92af255e256d27200d1..6b5c6653011b1a1d09c4bbc0af709383bbad66a3 100644 (file)
@@ -1,6 +1,7 @@
 #include "ggml.h"
 #include "llama.h"
 #include "common.h"
+#include "log.h"
 
 #include <unordered_map>
 #include <vector>
@@ -78,111 +79,101 @@ typedef struct {
 
 struct TransformerWeights {
     // token embedding table
-    float* token_embedding_table;    // (vocab_size, dim)
+    std::vector<float> token_embedding_table;    // (vocab_size, dim)
     // weights for rmsnorms
-    float* rms_att_weight; // (layer, dim) rmsnorm weights
-    float* rms_ffn_weight; // (layer, dim)
+    std::vector<float> rms_att_weight; // (layer, dim) rmsnorm weights
+    std::vector<float> rms_ffn_weight; // (layer, dim)
     // weights for matmuls
-    float* wq; // (layer, dim, dim)
-    float* wk; // (layer, dim, dim)
-    float* wv; // (layer, dim, dim)
-    float* wo; // (layer, dim, dim)
+    std::vector<float> wq; // (layer, dim, dim)
+    std::vector<float> wk; // (layer, dim, dim)
+    std::vector<float> wv; // (layer, dim, dim)
+    std::vector<float> wo; // (layer, dim, dim)
     // weights for ffn
-    float* w1; // (layer, hidden_dim, dim)
-    float* w2; // (layer, dim, hidden_dim)
-    float* w3; // (layer, hidden_dim, dim)
+    std::vector<float> w1; // (layer, hidden_dim, dim)
+    std::vector<float> w2; // (layer, dim, hidden_dim)
+    std::vector<float> w3; // (layer, hidden_dim, dim)
     // final rmsnorm
-    float* rms_final_weight; // (dim,)
+    std::vector<float> rms_final_weight; // (dim,)
     // freq_cis for RoPE relatively positional embeddings
-    // float* freq_cis_real; // (seq_len, dim/2)
-    // float* freq_cis_imag; // (seq_len, dim/2)
+    // std::vector<float> freq_cis_real; // (seq_len, dim/2)
+    // std::vector<float> freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
-    float* wcls;
-
-    ~TransformerWeights() {
-        delete[] token_embedding_table;
-        delete[] rms_att_weight;
-        delete[] rms_ffn_weight;
-        delete[] wq;
-        delete[] wk;
-        delete[] wv;
-        delete[] wo;
-        delete[] w1;
-        delete[] w2;
-        delete[] w3;
-        delete[] rms_final_weight;
-        delete[] wcls;
-    }
+    std::vector<float> wcls;
 };
 
-static void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
-    // we calloc instead of malloc to keep valgrind happy
-    w->token_embedding_table = new float[p->vocab_size * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
+static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) {
+    const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads;
+    try {
+        w->token_embedding_table.resize(p->vocab_size * p->dim);
+        LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
 
-    w->rms_att_weight = new float[p->n_layers * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim);
+        w->rms_att_weight.resize(p->n_layers * p->dim);
+        LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim);
 
-    w->rms_ffn_weight = new float[p->n_layers * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim);
+        w->rms_ffn_weight.resize(p->n_layers * p->dim);
+        LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim);
 
-    w->wq = new float[p->n_layers * p->dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
+        w->wq.resize(p->n_layers * p->dim * p->dim);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
 
-    w->wk = new float[p->n_layers * p->dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
+        w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
 
-    w->wv = new float[p->n_layers * p->dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
+        w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
 
-    w->wo = new float[p->n_layers * p->dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
+        w->wo.resize(p->n_layers * p->dim * p->dim);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
 
-    w->w1 = new float[p->n_layers * p->hidden_dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
+        w->w1.resize(p->n_layers * p->hidden_dim * p->dim);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
 
-    w->w2 = new float[p->n_layers * p->hidden_dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim);
+        w->w2.resize(p->n_layers * p->hidden_dim * p->dim);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim);
 
-    w->w3 = new float[p->n_layers * p->hidden_dim * p->dim]();
-    printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
+        w->w3.resize(p->n_layers * p->hidden_dim * p->dim);
+        LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
 
-    w->rms_final_weight = new float[p->dim]();
-    printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
+        w->rms_final_weight.resize(p->dim);
+        LOG("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
 
-    if (shared_weights) {
-        w->wcls = NULL;
-    } else {
-        w->wcls = new float[p->vocab_size * p->dim]();
-        printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
+        if (shared_weights) {
+            w->wcls = {};
+        } else {
+            w->wcls.resize(p->vocab_size * p->dim);
+            LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
+        }
+    }
+    catch (std::length_error &) {
+        die("Invalid configuration. Failed to allocate memory for weights");
     }
 }
 
-static int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
-    if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
-    if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
-    if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
-    if (fread(w->wk, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
-    if (fread(w->wv, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
-    if (fread(w->wo, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
-    if (fread(w->rms_ffn_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
-    if (fread(w->w1, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
-    if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
-    if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
-    if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
+static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) {
+    if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1;
+    if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1;
+    if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1;
+    if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1;
+    if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1;
+    if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1;
+    if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1;
+    if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1;
+    if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1;
+    if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1;
+    if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1;
 
     // Skip freq_cis_real & freq_cis_imag
     int head_size = p->dim / p->n_heads;
     fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
 
-    if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
+    if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1;
 
     // Check we didn't forget to read anything
     auto curr = ftell(f);
     fseek(f, 0, SEEK_END);
     auto end = ftell(f);
     if (curr != end) {
-        printf("Error: failed to read the checkpoint file to the end (curr = %ld, end =  %ld)\n", curr, end);
+        LOG("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end =  %ld)\n", __func__, curr, end);
         return 1;
     }
 
@@ -190,20 +181,20 @@ static int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bo
 }
 
 static void print_sample_weights(TransformerWeights *w){
-    printf("----- Quick print of first of the weight vales of all the variables\n");
-    printf("%f\n", w->token_embedding_table[0]);
-    printf("%f\n", w->rms_att_weight[0]);
-    printf("%f\n", w->rms_ffn_weight[0]);
-
-    printf("%f\n", w->wq[0]);
-    printf("%f\n", w->wk[0]);
-    printf("%f\n", w->wv[0]);
-    printf("%f\n", w->wo[0]);
-    printf("%f\n", w->w1[0]);
-    printf("%f\n", w->w2[0]);
-    printf("%f\n", w->w3[0]);
-    printf("%f\n", w->rms_att_weight[0]);
-    if (w->wcls) printf("%f\n", w->wcls[0]);
+    LOG("----- Quick print of first of the weight vales of all the variables\n");
+    LOG("%f\n", w->token_embedding_table[0]);
+    LOG("%f\n", w->rms_att_weight[0]);
+    LOG("%f\n", w->rms_ffn_weight[0]);
+
+    LOG("%f\n", w->wq[0]);
+    LOG("%f\n", w->wk[0]);
+    LOG("%f\n", w->wv[0]);
+    LOG("%f\n", w->wo[0]);
+    LOG("%f\n", w->w1[0]);
+    LOG("%f\n", w->w2[0]);
+    LOG("%f\n", w->w3[0]);
+    LOG("%f\n", w->rms_att_weight[0]);
+    if (!w->wcls.empty()) LOG("%f\n", w->wcls[0]);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -225,14 +216,16 @@ struct llama_vocab {
 };
 
 struct my_llama_hparams {
-    uint32_t n_vocab = 32000;
-    uint32_t n_ctx   = 512;   // this is provided as user input?
-    uint32_t n_embd  = 4096;
-    uint32_t n_ff    = 11008;
-    uint32_t n_mult  = 4;
-    uint32_t n_head  = 32;
-    uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
+    uint32_t n_vocab   = 32000;
+    uint32_t n_ctx     = 512;   // this is provided as user input?
+    uint32_t n_embd    = 4096;
+    uint32_t n_ff      = 11008;
+    uint32_t n_mult    = 4;
+    uint32_t n_head    = 32;
+    uint32_t n_head_kv = 32;
+    uint32_t n_layer   = 32;
+    uint32_t n_rot     = 64;
+
     bool operator!=(const my_llama_hparams& other) const {
         return memcmp(this, &other, sizeof(my_llama_hparams));
     }
@@ -325,14 +318,30 @@ struct train_params {
 };
 
 static void print_params(struct my_llama_hparams * params) {
-    printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
-    printf("%s: n_ctx:   %u\n", __func__, params->n_ctx);
-    printf("%s: n_embd:  %u\n", __func__, params->n_embd);
-    printf("%s: n_mult:  %u\n", __func__, params->n_mult);
-    printf("%s: n_head:  %u\n", __func__, params->n_head);
-    printf("%s: n_ff:    %u\n", __func__, params->n_ff);
-    printf("%s: n_layer: %u\n", __func__, params->n_layer);
-    printf("%s: n_rot:   %u\n", __func__, params->n_rot);
+    LOG("%s: n_vocab:   %u\n", __func__, params->n_vocab);
+    LOG("%s: n_ctx:     %u\n", __func__, params->n_ctx);
+    LOG("%s: n_embd:    %u\n", __func__, params->n_embd);
+    LOG("%s: n_mult:    %u\n", __func__, params->n_mult);
+    LOG("%s: n_head:    %u\n", __func__, params->n_head);
+    LOG("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
+    LOG("%s: n_ff:      %u\n", __func__, params->n_ff);
+    LOG("%s: n_layer:   %u\n", __func__, params->n_layer);
+    LOG("%s: n_rot:     %u\n", __func__, params->n_rot);
+}
+
+static void print_tensor_info(const struct ggml_context * ctx) {
+    for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        LOG("%s: Allocating ", __func__);
+        int64_t total = 1;
+        int i = 0;
+        for (; i < ggml_n_dims(t); ++i) {
+            if (i > 0) LOG("x ");
+            LOG("[%" PRId64 "] ", t->ne[i]);
+            total *= t->ne[i];
+        }
+        if (i > 1) LOG("= [%" PRId64 "] ", total);
+        LOG("float space for %s\n", ggml_get_name(t));
+    }
 }
 
 static void init_model(struct my_llama_model * model) {
@@ -342,6 +351,8 @@ static void init_model(struct my_llama_model * model) {
     const uint32_t n_layer = hparams.n_layer;
     const uint32_t n_vocab = hparams.n_vocab;
 
+    const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv;
+
     const uint32_t n_ff = hparams.n_ff;
     struct ggml_context * ctx = model->ctx;
 
@@ -350,25 +361,8 @@ static void init_model(struct my_llama_model * model) {
     model->train_tokens = 0;
 
     model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
-    printf("[%s:GG] Allocating [%u] x [%u] = [%u] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab);
-
     model->norm           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-    printf("[%s:GG] Allocating [%u] float space for model->norm\n",__func__,n_embd);
-
     model->output         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab);
-
-    // printing the per-layer allocations here so we dont print in the for loop.
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wq for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wk for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wv for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wo for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
-
-    printf("[%s:GG] Allocating [%u] float space for layer.ffn_norm for [%u] layers\n",__func__,n_embd, n_layer);
-
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w1 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w2 for [%u] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer);
-    printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w3 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
 
     ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
     ggml_set_name(model->norm,           "norm.weight");
@@ -383,8 +377,8 @@ static void init_model(struct my_llama_model * model) {
         layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
         layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-        layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
-        layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
+        layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
+        layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
         layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
 
         layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
@@ -406,6 +400,8 @@ static void init_model(struct my_llama_model * model) {
         ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str());
         ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str());
     }
+
+    print_tensor_info(ctx);
 }
 
 static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
@@ -421,9 +417,9 @@ static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
 static void print_row(struct ggml_tensor * probs, int i) {
     for (int k = 0; k < probs->ne[0]; ++k) {
         float p = get_f32_2d(probs, k, i);
-        printf(" %f", p);
+        LOG(" %f", p);
     }
-    printf("\n");
+    LOG("\n");
 }
 
 static void print_matrix(struct ggml_tensor * probs) {
@@ -431,33 +427,12 @@ static void print_matrix(struct ggml_tensor * probs) {
     for (int i = 0; i < probs->ne[1]; ++i) {
         for (int k = 0; k < probs->ne[0]; ++k) {
             float p = get_f32_2d(probs, k, i);
-            printf(" %.2f", p);
+            LOG(" %.2f", p);
         }
-        printf("\n");
+        LOG("\n");
     }
 }
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((format(gnu_printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static std::string format(const char * fmt, ...) {
-    va_list ap, ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX);
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
 struct llama_file {
     // use FILE * so we don't have to re-open the file to mmap
     FILE * fp;
@@ -549,8 +524,9 @@ static std::string llama_escape_whitespaces(const std::string & text) {
     return out.str();
 }
 
-static void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
+static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
     if (is_ggml_file(filename)) {
+        LOG("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
         struct ggml_context * ctx_data = NULL;
 
         struct gguf_init_params params = {
@@ -578,6 +554,9 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
         const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
 
         const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+        if (n_vocab != static_cast<uint32_t>(config->vocab_size)) {
+            die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size);
+        }
 
         vocab->id_to_token.resize(n_vocab);
 
@@ -595,7 +574,7 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
         gguf_free(ctx);
     } else {
         // assume llama2.c vocabulary
-        printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
+        LOG("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename);
         llama_file file(filename, "rb");
         if (!file.fp) {
             die_fmt("%s: %s", strerror(errno), filename);
@@ -638,38 +617,15 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
 }
 
 static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
-    int ct;
-    switch (ggml_n_dims(gg_weights)) {
-        case 1:
-            ct = 0;
-            for (int i0 = 0; i0 < gg_weights->ne[0]; i0++){
-                float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0]);
-                *ptr = karpathy_weights[ct];
-                ct++;
-            }
-            break;
-        case 2:
-            ct = 0;
-            for (int i1 = 0; i1 < gg_weights->ne[1]; i1++) {
-                for (int i0 = 0; i0 < gg_weights->ne[0]; i0++) {
-                    float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0] + i1*gg_weights->nb[1]);
-                    *ptr = karpathy_weights[ct];
-                    ct++;
-                }
-            }
-            break;
-        case 3:
-            ct = 0;
-            for (int i2 = 0; i2 < gg_weights->ne[2]; i2++) {
-                for (int i1 = 0; i1 < gg_weights->ne[1]; i1++) {
-                    for (int i0 = 0; i0 < gg_weights->ne[0]; i0++) {
-                        float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0] + i1*gg_weights->nb[1] + i2*gg_weights->nb[2]);
-                        *ptr = karpathy_weights[ct];
-                        ct++;
-                    }
-                }
-            }
-            break;
+    int size = 1;
+    for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) {
+        size *= gg_weights->ne[dim];
+    }
+    for (int ct = 0; ct < size; ++ct) {
+        int64_t i0 = 0; int64_t i1 = 0;
+        int64_t i2 = 0; int64_t i3 = 0;
+        ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3);
+        ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]);
     }
 }
 
@@ -679,16 +635,18 @@ static void save_as_llama_model(
     // convert AK weights into GG weights one by one.
     // w->token_embedding_table -> model->tok_embeddings
     // float*                   -> struct ggml_tensor
-    convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table);
-    convert_weights_ak_to_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
+    convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data());
+    convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data());
 
-    convert_weights_ak_to_gg(model->norm, w->rms_final_weight);
+    convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data());
     //print_row(model->norm, 0);
 
     // for rms-att-weight
     int row_length = model->hparams.n_embd;
     int n_ff = model->hparams.n_ff;
 
+    const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv;
+
     for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
         auto & layer = model->layers[i];
         // 1d
@@ -697,9 +655,10 @@ static void save_as_llama_model(
 
         // from 3d matrix layer x dim x dim to 2d matrix dim x dim
         convert_weights_ak_to_gg(layer.wq            , &w->wq[i*row_length*row_length]);
-        convert_weights_ak_to_gg(layer.wk            , &w->wk[i*row_length*row_length]);
-        convert_weights_ak_to_gg(layer.wv            , &w->wv[i*row_length*row_length]);
         convert_weights_ak_to_gg(layer.wo            , &w->wo[i*row_length*row_length]);
+        // from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries
+        convert_weights_ak_to_gg(layer.wk            , &w->wk[i*row_length*row_length/n_multiqueries]);
+        convert_weights_ak_to_gg(layer.wv            , &w->wv[i*row_length*row_length/n_multiqueries]);
 
         convert_weights_ak_to_gg(layer.w1            , &w->w1[i*row_length*n_ff]);
         convert_weights_ak_to_gg(layer.w2            , &w->w2[i*n_ff*row_length]);
@@ -736,8 +695,8 @@ static void save_as_llama_model(
     gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd);
     gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff);
     gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
-    // n_head_kv is optional, default to n_head
-    // gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, ...);
+    gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
+    gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv);
     gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer);
     gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot);
     gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
@@ -789,12 +748,12 @@ static void save_as_llama_model(
 
 static struct train_params get_default_train_params() {
     struct train_params params;
-    params.fn_vocab_model    = "models/7B/ggml-model-f16.gguf";
+    params.fn_vocab_model          = "models/7B/ggml-model-f16.gguf";
     params.fn_llama2c_output_model = "ak_llama_model.bin";
-    params.fn_train_data     = "shakespeare.txt";
-    params.fn_checkpoint_in  = "checkpoint.bin";
-    params.fn_checkpoint_out = "checkpoint.bin";
-    params.fn_model_out      = "ggml-checkpoint-f32.bin";
+    params.fn_train_data           = "shakespeare.txt";
+    params.fn_checkpoint_in        = "checkpoint.bin";
+    params.fn_checkpoint_out       = "checkpoint.bin";
+    params.fn_model_out            = "ggml-checkpoint-f32.bin";
 
     params.seed       =   -1;
 
@@ -829,8 +788,8 @@ static struct train_params get_default_train_params() {
     params.adam_alpha        = 1e-3f;
     params.adam_decay        = 1e-3f;
 
-    params.mem_model_gb   = 2;
-    params.mem_compute_gb = 24;
+    params.mem_model_gb    = 2;
+    params.mem_compute_gb  = 24;
     params.mem_compute0_gb = 8;
     params.mem_compute1_gb = 2;
 
@@ -916,19 +875,30 @@ int main(int argc, char ** argv) {
     if (!params_parse(argc, argv, &params)) {
         return 1;
     }
+    log_set_target(stdout);
     Config config;
     TransformerWeights weights = {};
     {
-        FILE *file = fopen(params.fn_llama2c_model, "rb");
-        if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
+        LOG("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model);
+        FILE *file = fopen(params.fn_llama2c_model, "r");
+        if (!file) {
+            LOG("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model);
+            return 1;
+        }
         // read in the config header
-        if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
+        if (fread(&config, sizeof(Config), 1, file) != 1) {
+            LOG("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model);
+            return 1;
+        }
         auto shared_weights = config.vocab_size > 0;
         config.vocab_size = abs(config.vocab_size);
 
         // read in the Transformer weights
-        malloc_weights(&weights, &config, shared_weights);
-        if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; }
+        alloc_weights(&weights, &config, shared_weights);
+        if (checkpoint_init_weights(&weights, &config, file, shared_weights)) {
+            LOG("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model);
+            return 1;
+        }
         fclose(file);
     }
 
@@ -936,15 +906,18 @@ int main(int argc, char ** argv) {
     load_vocab(params.fn_vocab_model, &config, &vocab);
 
     struct my_llama_model model;
-    model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx);
-    model.hparams.n_ctx   = params.n_ctx;
-    model.hparams.n_embd  = config.dim; //params.n_embd;
-    model.hparams.n_ff    = config.hidden_dim;
-    model.hparams.n_mult  = 32;//params.n_mult;
-    model.hparams.n_head  = config.n_heads; //params.n_head;
-    model.hparams.n_layer = config.n_layers; //params.n_layer;
-    model.hparams.n_rot   = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+    model.hparams.n_vocab   = config.vocab_size; //llama_n_vocab(lctx);
+    model.hparams.n_ctx     = params.n_ctx;
+    model.hparams.n_embd    = config.dim; //params.n_embd;
+    model.hparams.n_ff      = config.hidden_dim;
+    model.hparams.n_mult    = 32;//params.n_mult;
+    model.hparams.n_head    = config.n_heads; //params.n_head;
+    model.hparams.n_head_kv = config.n_kv_heads;
+    model.hparams.n_layer   = config.n_layers; //params.n_layer;
+    model.hparams.n_rot     = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+
     print_params(&model.hparams);
+
     struct ggml_init_params lcparams;
     lcparams.mem_size   = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
     lcparams.mem_buffer = NULL;
@@ -956,7 +929,7 @@ int main(int argc, char ** argv) {
     model.name = basename(params.fn_llama2c_model);
     save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
 
-    printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
+    LOG("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
     return 0;