llm : add Falcon support (#2717)
author Georgi Gerganov <redacted>
Wed, 23 Aug 2023 20:08:04 +0000 (23:08 +0300)
committer GitHub <redacted>
Wed, 23 Aug 2023 20:08:04 +0000 (23:08 +0300)
* llama : refactor GGUF constants into static maps

* llama : check if model architecture is known

* llama : refactor llama_model_load_internal()

* gguf : add KV constant maps

* llm : read arch-specific KVs

* convert : add dummy scores + types

* falcon : load tensor data (CPU only)

* llama : fix loading progress bar

* llama : add arch member to llama_model

* falcon : CPU inference working

* falcon : support non-40B models

* falcon : minor

* llama : minor updates

ggml-ci

* convert-falcon-hf-to-gguf.py : fix special token mapping

* llama.cpp : llama default UNK token = id 0

* llama.cpp : fix bpe tokenizer

* llama.cpp : fix the fix of bpe tokenizer

* ggml : pass eps to ggml_norm

* metal : implement RoPE (mode = 2) + avoid ggml_repeat

* ggml : ggml_repeat always creates new tensor

* falcon : copy-paste self-attention from LLaMA

* metal : print extra compute pipeline info

* falcon : minor changes (still chasing the Metal problem)

* llama.cpp : fix linefeed token

* metal : fix GELU kernel numerical stability by using precise::tanh

* metal : temporary workaround for the concurrency optimization bug

* falcon : add CUDA offloading (#2739)

* llama : better model naming and size reporting

* llama : prep new tokenizer support

* llama : advanced BPE tokenizer based on ggllm.cpp implementation

* llama : remove obsolete comment

ggml-ci

* common : remove obsolete BPE API + disable test-tokenizer-1

* llama : revert BPE special-case in llama_byte_to_token()

* cuda : add TODOs for RoPE NeoX implementation

* llama : default special tokens based on vocab type

* perplexity : add log for start of tokenization

---------

Co-authored-by: klosax <redacted>
Co-authored-by: slaren <redacted>
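
Editor's note on the "refactor GGUF constants into static maps" / "llm : read arch-specific KVs" items above: the per-architecture key names are now built from "%s" templates at runtime. A rough sketch of how a key resolves, using the LLM_KV helper added to llama.cpp further down (the helper and enums are file-local to llama.cpp; this fragment is illustrative only, not part of the public API):

    // resolve an architecture-specific GGUF key name
    const auto kv = LLM_KV(LLM_ARCH_FALCON);

    // "%s.attention.head_count_kv" -> "falcon.attention.head_count_kv"
    const std::string key = kv(LLM_KV_ATTENTION_HEAD_COUNT_KV);

    // the value would then be read with the GGUF_GET_KEY macro introduced in the same file, e.g.:
    // GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);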
18 files changed:
common/common.cpp
common/common.h
convert-falcon-hf-to-gguf.py
convert.py
examples/main/main.cpp
examples/perplexity/perplexity.cpp
ggml-alloc.c
ggml-alloc.h
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
gguf.py
llama.cpp
llama.h
tests/CMakeLists.txt
tests/test-tokenizer-1.cpp

diff --git a/common/common.cpp b/common/common.cpp
index 88a962ae385de5f5d918dd81412465edc6554b55..53002ba306b572ca9d11609f3aceb063f51575a1 100644 (file)
@@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
 
     return std::string(result.data(), result.size());
 }
-
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-           const std::string & text,
-                        bool   add_bos) {
-    int n_tokens = text.length() + add_bos;
-    std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-    return result;
-}
-
-std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
diff --git a/common/common.h b/common/common.h
index d68a8ef88c97cc1ad924189abf3e7ee88423c45c..17d271e6750e27210151b06457fb3eb3eb500187 100644 (file)
@@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
            const std::string & text,
                         bool   add_bos);
 
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-           const std::string & text,
-                        bool   add_bos);
-
 std::string llama_token_to_str(
         const struct llama_context * ctx,
                        llama_token   token);
-
-std::string llama_token_to_str_bpe(
-    const struct llama_context * ctx,
-                   llama_token   token);
diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py
index 50069db56213cd22375904b428edcefdc536a4ee..43e208497a7bc14d5ada46a1902b1d4b24001afd 100755 (executable)
@@ -95,14 +95,17 @@ print("gguf: get model metadata")
 
 block_count = hparams["n_layer"]
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name("Falcon")
 gguf_writer.add_context_length(2048) # not in config.json
 gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_head_count(hparams["n_head"])
-if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+if "n_head_kv" in hparams:
+    gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+else:
+    gguf_writer.add_head_count_kv(1)
 gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 
 # TOKENIZATION
@@ -110,6 +113,8 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 print("gguf: get tokenizer metadata")
 
 tokens: List[str] = []
+scores: List[float] = []
+toktypes: List[int] = []
 merges: List[str] = []
 
 
@@ -153,41 +158,30 @@ if Path(dir_model + "/tokenizer.json").is_file():
             text = bytearray(pad_token)
 
         tokens.append(text)
+        scores.append(0.0)                      # dummy
+        toktypes.append(gguf.TokenType.NORMAL)  # dummy
 
     gguf_writer.add_token_list(tokens)
+    gguf_writer.add_token_scores(scores)
+    gguf_writer.add_token_types(toktypes)
 
-    if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
-        print("gguf: get special token ids")
-
-        with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-            tokenizer_config = json.load(f)
+print("gguf: get special token ids")
+# Look for special tokens in config.json
 
-        # find special token ids
+if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
+    gguf_writer.add_bos_token_id(hparams["bos_token_id"])
 
-        if "bos_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["bos_token"]:
-                    gguf_writer.add_bos_token_id(key["id"])
+if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
+    gguf_writer.add_eos_token_id(hparams["eos_token_id"])
 
-        if "eos_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["eos_token"]:
-                    gguf_writer.add_eos_token_id(key["id"])
+if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
+    gguf_writer.add_unk_token_id(hparams["unk_token_id"])
 
-        if "unk_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["unk_token"]:
-                    gguf_writer.add_unk_token_id(key["id"])
+if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
+    gguf_writer.add_sep_token_id(hparams["sep_token_id"])
 
-        if "sep_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["sep_token"]:
-                    gguf_writer.add_sep_token_id(key["id"])
-
-        if "pad_token" in tokenizer_config:
-            for key in tokenizer_json["added_tokens"]:
-                if key["content"] == tokenizer_config["pad_token"]:
-                    gguf_writer.add_pad_token_id(key["id"])
+if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
+    gguf_writer.add_pad_token_id(hparams["pad_token_id"])
 
 
 # TENSORS
@@ -195,8 +189,9 @@ if Path(dir_model + "/tokenizer.json").is_file():
 tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 
 # params for qkv transform
-n_head = hparams["n_head"]
+n_head    = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
+
 head_dim = hparams["hidden_size"] // n_head
 
 # tensor info
diff --git a/convert.py b/convert.py
index a701ab41b436a7d9f97a92e604271499534cc10f..8d34d5f291ddb68dbc5561e86d1a2ce14af2248d 100755 (executable)
@@ -733,7 +733,11 @@ class OutputFile:
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
 
     def add_meta_arch(self, params: Params) -> None:
-        self.gguf.add_name                ("LLaMA")
+        ver = None
+        if (params.n_ctx == 4096):
+            ver = "v2"
+
+        self.gguf.add_name                ("LLaMA" if ver == None else "LLaMA " + ver)
         self.gguf.add_context_length      (params.n_ctx)
         self.gguf.add_embedding_length    (params.n_embd)
         self.gguf.add_block_count         (params.n_layer)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 0a22f3c25ff4624a5d985a2bf874e14e49749173..1393f0b084a2116ee2dff3a806df1a1a258e7010 100644 (file)
@@ -43,7 +43,7 @@ static bool is_interacting = false;
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
-            is_interacting=true;
+            is_interacting = true;
         } else {
             console::cleanup();
             printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
         }
     }
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
     } else {
         embd_inp = session_tokens;
     }
@@ -208,9 +210,9 @@ int main(int argc, char ** argv) {
     int original_prompt_len = 0;
     if (ctx_guidance) {
         params.cfg_negative_prompt.insert(0, 1, ' ');
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
     }
@@ -257,8 +259,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
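
A condensed sketch of the tokenization change shown above (the same pattern is repeated in perplexity.cpp below), assuming an initialized llama_context * ctx and the usual gpt_params params: the BOS token is now only prepended for SentencePiece (SPM) vocabularies, since BPE models such as Falcon do not use one.

    // SPM (LLaMA-style) vocab expects a leading BOS token; BPE (Falcon) does not
    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;

    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);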
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index e89725efc3db6bedbec50c6984553b4c8bfb9904..a7bd9db2a3fd323286cad1fc667c5d3367aad22f 100644 (file)
@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
 }
 
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
-
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
         return;
     }
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int calc_chunk = params.n_ctx;
 
@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
 }
 
 void perplexity(llama_context * ctx, const gpt_params & params) {
-
     if (params.ppl_stride > 0) {
         perplexity_v2(ctx, params);
         return;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     size_t hs_task_count = prompt_lines.size()/6;
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // This is needed as usual for LLaMA models
-    bool prepend_bos = true;
+    const bool add_bos = is_spm;
 
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count  ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     std::vector<float> tok_logits(n_vocab);
 
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-
         // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
         size_t context_size = context_embd.size();
 
         // Do the 1st ending
         // In this case we include the context when evaluating
-        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
         auto query_size = query_embd.size();
         //printf("First query: %d\n",(int)query_size);
 
diff --git a/ggml-alloc.c b/ggml-alloc.c
index f06f9a3c1d97b97c700631fe810f62fec280310a..547ec0399fdb5b174dd865dab2ffe9284c6263a9 100644 (file)
@@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
+void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
     int pos = 0;
     for (int i = 0; i < n; i++) {
         if (list[i] != -1) {
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                         struct ggml_tensor * view_src = get_view_source(parent);
                         struct hash_node * view_src_hn = hash_get(ht, view_src);
                         view_src_hn->n_views -= 1;
-                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                        AT_PRINTF("view_src %s\n", view_src->name);
                         if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                             ggml_allocator_free_tensor(alloc, view_src);
                         }
diff --git a/ggml-alloc.h b/ggml-alloc.h
index 14a4350ac2e968f5455632f2b7ad5dba5f4a2d8e..9559da75871a608fb29f08edebf860674295e043 100644 (file)
@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
+GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
 
 GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
 GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 70a950bb58b9b0b09356d8e9a226a9f339b91067..868b7a7b905a24ea7ea27a856e3834ee6a6c74a8 100644 (file)
@@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+// TODO: this implementation is wrong!
+//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+//                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//    if (col >= ncols) {
+//        return;
+//    }
+//
+//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//    const int i = row*ncols + col/2;
+//
+//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+//    const float sin_theta = sinf(theta);
+//    const float cos_theta = cosf(theta);
+//
+//    const float x0 = x[i + 0];
+//    const float x1 = x[i + ncols/2];
+//
+//    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
+//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+//}
+
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
@@ -5515,7 +5538,8 @@ inline void ggml_cuda_op_rope(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    const bool is_glm = mode & 4;
+    const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
@@ -5523,6 +5547,9 @@ inline void ggml_cuda_op_rope(
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else if (is_neox) {
+        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
+#pragma message("TODO: implement RoPE NeoX for CUDA")
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
diff --git a/ggml-metal.m b/ggml-metal.m
index 835c5f297cf95d0b6cff9a16eab111691076f3ad..969cf7daa74c5e8d47a77dd1f83b0010b858b462 100644 (file)
@@ -167,7 +167,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
             fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
@@ -218,12 +220,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+        fprintf(stderr, "%s: maxTransferRate               = built-in GPU\n", __func__);
     }
 
     return ctx;
@@ -537,8 +539,8 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
-            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
             for (int ind = node_start; ind < node_end; ++ind) {
                 const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -744,32 +746,31 @@ void ggml_metal_graph_compute(
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
-                                    switch (src0->type) {
-                                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
-                                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
-                                    }
-                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-                                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-                                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
-                                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
-                                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
-                                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                switch (src0->type) {
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                                 }
-                            else {
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
+                                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
 
@@ -868,24 +869,24 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -938,16 +939,17 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            const float eps = 1e-5f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -990,7 +992,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+
                             const int nth = 32;
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
@@ -1005,8 +1009,8 @@ void ggml_metal_graph_compute(
                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
@@ -1057,24 +1061,24 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index ce3541f4bb55f6a06178975ad6b0fa51a2435906..7bc3fdf371897311707bb76d328ee7b22de5b2a1 100644 (file)
@@ -87,7 +87,12 @@ kernel void kernel_gelu(
     device       float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
@@ -571,7 +576,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
 
diff --git a/ggml.c b/ggml.c
index dffb977313584e5efd8b75d34e29e3bbe61724a1..8cb5c404f285da3aea58bf9b6380e62ee929d6ca 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3554,9 +3554,9 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const float GELU_COEF_A    = 0.044715f;
-static const float GELU_QUICK_COEF    = -1.702f;
-static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -5555,10 +5555,6 @@ struct ggml_tensor * ggml_repeat(
         is_node = true;
     }
 
-    if (ggml_are_same_shape(a, b) && !is_node) {
-        return a;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
 
     result->op   = GGML_OP_REPEAT;
@@ -5789,6 +5785,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5799,7 +5796,7 @@ static struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5810,14 +5807,16 @@ static struct ggml_tensor * ggml_norm_impl(
 
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }
 
 // ggml_rms_norm
@@ -10619,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -12537,7 +12537,7 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
@@ -12666,7 +12666,7 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
diff --git a/ggml.h b/ggml.h
index 3c48fd27fab39d2986f76f3fc4a3b1543a6cb915..421c0df60c579e2069f9dc5130dc3949d9ddee65 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -909,14 +909,15 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
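
A minimal migration sketch for the ggml_norm change above, assuming an existing ggml context and input tensor (the names ctx0 and cur are placeholders): the epsilon that used to be hard-coded to 1e-5 inside ggml is now supplied by the caller, e.g. from the model's layer-norm epsilon.

    // before this commit: cur = ggml_norm(ctx0, cur);   // eps fixed to 1e-5f internally
    // after this commit: the caller passes eps explicitly
    const float norm_eps = 1e-5f;                        // e.g. hparams.f_norm_eps read from GGUF
    cur = ggml_norm(ctx0, cur, norm_eps);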
diff --git a/gguf.py b/gguf.py
index 9421080b80528ba7abf81a7dab56d31d3715783d..5c37f0f0b78e021b278fd11e66094c7d0941ad55 100755 (executable)
--- a/gguf.py
+++ b/gguf.py
@@ -30,12 +30,12 @@ KEY_GENERAL_SOURCE_HF_REPO       = "general.source.hugginface.repository"
 KEY_GENERAL_FILE_TYPE            = "general.file_type"
 
 # LLM
-KEY_LLM_CONTEXT_LENGTH        = "{arch}.context_length"
-KEY_LLM_EMBEDDING_LENGTH      = "{arch}.embedding_length"
-KEY_LLM_BLOCK_COUNT           = "{arch}.block_count"
-KEY_LLM_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
-KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-KEY_LLM_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+KEY_CONTEXT_LENGTH        = "{arch}.context_length"
+KEY_EMBEDDING_LENGTH      = "{arch}.embedding_length"
+KEY_BLOCK_COUNT           = "{arch}.block_count"
+KEY_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
+KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+KEY_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
 
 # attention
 KEY_ATTENTION_HEAD_COUNT        = "{arch}.attention.head_count"
@@ -583,7 +583,7 @@ class GGUFWriter:
         self.add_string(KEY_GENERAL_AUTHOR, author)
 
     def add_tensor_data_layout(self, layout: str):
-        self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+        self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_url(self, url: str):
         self.add_string(KEY_GENERAL_URL, url)
@@ -613,27 +613,27 @@ class GGUFWriter:
 
     def add_context_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length)
+            KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
 
     def add_embedding_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length)
+            KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
 
     def add_block_count(self, length: int):
         self.add_uint32(
-            KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length)
+            KEY_BLOCK_COUNT.format(arch=self.arch), length)
 
     def add_feed_forward_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+            KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
     def add_parallel_residual(self, use: bool):
         self.add_bool(
-            KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+            KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
     def add_tensor_data_layout(self, layout: str):
         self.add_string(
-            KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+            KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_head_count(self, count: int):
         self.add_uint32(
index fd8eaa1800bde19e74a3c45149b756b8f320ee1f..f2dc4da1db344221ad10749535128ac84d650a0d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -72,6 +72,7 @@
 #include <numeric>
 #include <queue>
 #include <random>
+#include <regex>
 #include <sstream>
 #include <thread>
 #include <unordered_map>
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-// tensor names
-#define TN_TOKEN_EMBD  "token_embd.weight"
-#define TN_OUTPUT_NORM "output_norm.weight"
-#define TN_OUTPUT      "output.weight"
-#define TN_ATTN_NORM   "blk.%d.attn_norm.weight"
-#define TN_ATTN_Q      "blk.%d.attn_q.weight"
-#define TN_ATTN_K      "blk.%d.attn_k.weight"
-#define TN_ATTN_V      "blk.%d.attn_v.weight"
-#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight"
-#define TN_FFN_NORM    "blk.%d.ffn_norm.weight"
-#define TN_FFN_GATE    "blk.%d.ffn_gate.weight"
-#define TN_FFN_DOWN    "blk.%d.ffn_down.weight"
-#define TN_FFN_UP      "blk.%d.ffn_up.weight"
-
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 //
 // logging
 //
+
 LLAMA_ATTRIBUTE_FORMAT(2, 3)
 static void llama_log_internal        (llama_log_level level, const char* format, ...);
 static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data);
@@ -119,6 +107,21 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
 // helpers
 //
 
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
+void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    for (size_t pos = 0; ; pos += replace.length()) {
+        pos = s.find(search, pos);
+        if (pos == std::string::npos) break;
+        s.erase(pos, search.length());
+        s.insert(pos, replace);
+    }
+}
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -142,6 +145,241 @@ static std::string format(const char * fmt, ...) {
     return std::string(buf.data(), size);
 }
 
+//
+// gguf constants (sync with gguf.py)
+//
+
+enum llm_arch {
+    LLM_ARCH_LLAMA,
+    LLM_ARCH_FALCON,
+    LLM_ARCH_GPT2,
+    LLM_ARCH_GPTJ,
+    LLM_ARCH_GPTNEOX,
+    LLM_ARCH_MPT,
+    LLM_ARCH_UNKNOWN,
+};
+
+static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
+    { LLM_ARCH_LLAMA,   "llama"   },
+    { LLM_ARCH_FALCON,  "falcon"  },
+    { LLM_ARCH_GPT2,    "gpt2"    },
+    { LLM_ARCH_GPTJ,    "gptj"    },
+    { LLM_ARCH_GPTNEOX, "gptneox" },
+    { LLM_ARCH_MPT,     "mpt"     },
+};
+
+enum llm_kv {
+    LLM_KV_GENERAL_ARCHITECTURE,
+    LLM_KV_GENERAL_QUANTIZATION_VERSION,
+    LLM_KV_GENERAL_ALIGNMENT,
+    LLM_KV_GENERAL_NAME,
+    LLM_KV_GENERAL_AUTHOR,
+    LLM_KV_GENERAL_URL,
+    LLM_KV_GENERAL_DESCRIPTION,
+    LLM_KV_GENERAL_LICENSE,
+    LLM_KV_GENERAL_SOURCE_URL,
+    LLM_KV_GENERAL_SOURCE_HF_REPO,
+
+    LLM_KV_CONTEXT_LENGTH,
+    LLM_KV_EMBEDDING_LENGTH,
+    LLM_KV_BLOCK_COUNT,
+    LLM_KV_FEED_FORWARD_LENGTH,
+    LLM_KV_USE_PARALLEL_RESIDUAL,
+    LLM_KV_TENSOR_DATA_LAYOUT,
+
+    LLM_KV_ATTENTION_HEAD_COUNT,
+    LLM_KV_ATTENTION_HEAD_COUNT_KV,
+    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
+    LLM_KV_ATTENTION_CLAMP_KQV,
+    LLM_KV_ATTENTION_LAYERNORM_EPS,
+    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
+
+    LLM_KV_ROPE_DIMENSION_COUNT,
+    LLM_KV_ROPE_SCALE_LINEAR,
+
+    LLM_KV_TOKENIZER_MODEL,
+    LLM_KV_TOKENIZER_LIST,
+    LLM_KV_TOKENIZER_TOKEN_TYPE,
+    LLM_KV_TOKENIZER_SCORES,
+    LLM_KV_TOKENIZER_MERGES,
+    LLM_KV_TOKENIZER_BOS_ID,
+    LLM_KV_TOKENIZER_EOS_ID,
+    LLM_KV_TOKENIZER_UNK_ID,
+    LLM_KV_TOKENIZER_SEP_ID,
+    LLM_KV_TOKENIZER_PAD_ID,
+    LLM_KV_TOKENIZER_HF_JSON,
+    LLM_KV_TOKENIZER_RWKV,
+};
+
+static std::map<llm_kv, std::string> LLM_KV_NAMES = {
+    { LLM_KV_GENERAL_ARCHITECTURE,          "general.architecture"         },
+    { LLM_KV_GENERAL_QUANTIZATION_VERSION,  "general.quantization_version" },
+    { LLM_KV_GENERAL_ALIGNMENT,             "general.alignment"            },
+    { LLM_KV_GENERAL_NAME,                  "general.name"                 },
+    { LLM_KV_GENERAL_AUTHOR,                "general.author"               },
+    { LLM_KV_GENERAL_URL,                   "general.url"                  },
+    { LLM_KV_GENERAL_DESCRIPTION,           "general.description"          },
+    { LLM_KV_GENERAL_LICENSE,               "general.license"              },
+    { LLM_KV_GENERAL_SOURCE_URL,            "general.source_url"           },
+    { LLM_KV_GENERAL_SOURCE_HF_REPO,        "general.source_hf_repo"       },
+
+    { LLM_KV_CONTEXT_LENGTH,                "%s.context_length"        },
+    { LLM_KV_EMBEDDING_LENGTH,              "%s.embedding_length"      },
+    { LLM_KV_BLOCK_COUNT,                   "%s.block_count"           },
+    { LLM_KV_FEED_FORWARD_LENGTH,           "%s.feed_forward_length"   },
+    { LLM_KV_USE_PARALLEL_RESIDUAL,         "%s.use_parallel_residual" },
+    { LLM_KV_TENSOR_DATA_LAYOUT,            "%s.tensor_data_layout"    },
+
+    { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"             },
+    { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,      "%s.attention.max_alibi_bias"         },
+    { LLM_KV_ATTENTION_CLAMP_KQV,           "%s.attention.clamp_kqv"              },
+    { LLM_KV_ATTENTION_LAYERNORM_EPS,       "%s.attention.layer_norm_epsilon"     },
+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,   "%s.attention.layer_norm_rms_epsilon" },
+
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count" },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"    },
+
+    { LLM_KV_TOKENIZER_MODEL,               "tokenizer.ggml.model"              },
+    { LLM_KV_TOKENIZER_LIST,                "tokenizer.ggml.tokens"             },
+    { LLM_KV_TOKENIZER_TOKEN_TYPE,          "tokenizer.ggml.token_type"         },
+    { LLM_KV_TOKENIZER_SCORES,              "tokenizer.ggml.scores"             },
+    { LLM_KV_TOKENIZER_MERGES,              "tokenizer.ggml.merges"             },
+    { LLM_KV_TOKENIZER_BOS_ID,              "tokenizer.ggml.bos_token_id"       },
+    { LLM_KV_TOKENIZER_EOS_ID,              "tokenizer.ggml.eos_token_id"       },
+    { LLM_KV_TOKENIZER_UNK_ID,              "tokenizer.ggml.unknown_token_id"   },
+    { LLM_KV_TOKENIZER_SEP_ID,              "tokenizer.ggml.seperator_token_id" },
+    { LLM_KV_TOKENIZER_PAD_ID,              "tokenizer.ggml.padding_token_id"   },
+    { LLM_KV_TOKENIZER_HF_JSON,             "tokenizer.huggingface.json"        },
+    { LLM_KV_TOKENIZER_RWKV,                "tokenizer.rwkv.world"              },
+};
+
+struct LLM_KV {
+    LLM_KV(llm_arch arch) : arch(arch) {}
+
+    llm_arch arch;
+
+    std::string operator()(llm_kv kv) const {
+        return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str());
+    }
+};
+
+enum llm_tensor {
+    LLM_TENSOR_TOKEN_EMBD,
+    LLM_TENSOR_POS_EMBD,
+    LLM_TENSOR_OUTPUT,
+    LLM_TENSOR_OUTPUT_NORM,
+    LLM_TENSOR_ROPE_FREQS,
+    LLM_TENSOR_ATTN_Q,
+    LLM_TENSOR_ATTN_K,
+    LLM_TENSOR_ATTN_V,
+    LLM_TENSOR_ATTN_QKV,
+    LLM_TENSOR_ATTN_OUT,
+    LLM_TENSOR_ATTN_NORM,
+    LLM_TENSOR_ATTN_NORM_2,
+    LLM_TENSOR_ATTN_ROT_EMBD,
+    LLM_TENSOR_FFN_GATE,
+    LLM_TENSOR_FFN_DOWN,
+    LLM_TENSOR_FFN_UP,
+    LLM_TENSOR_FFN_NORM,
+};
+
+static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
+    {
+        LLM_ARCH_LLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_FALCON,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
+};
+
+static llm_arch llm_arch_from_string(const std::string & name) {
+    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLM_ARCH_UNKNOWN;
+}
+
+// helper to handle gguf constants
+// usage:
+//
+//   const auto tn = LLM_TN(LLM_ARCH_LLAMA);
+//
+//   std::string name = tn(LLM_TENSOR_OUTPUT);                     -> "output"
+//   std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias");         -> "token_embd.bias"
+//   std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3);     -> "blk.3.attn_norm.weight"
+//
+struct LLM_TN {
+    LLM_TN(llm_arch arch) : arch(arch) {}
+
+    llm_arch arch;
+
+    std::string operator()(llm_tensor tensor) const {
+        return LLM_TENSOR_NAMES[arch].at(tensor);
+    }
+
+    std::string operator()(llm_tensor tensor, const std::string & suffix) const {
+        return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix;
+    }
+
+    std::string operator()(llm_tensor tensor, int bid) const {
+        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid);
+    }
+
+    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
+        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
+    }
+};
+
+//
+// gguf helpers
+//
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+{ \
+    const std::string skey(key); \
+    const int kid = gguf_find_key(ctx, skey.c_str()); \
+    if (kid >= 0) { \
+        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+        if (ktype != (type)) { \
+            throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
+        } \
+        (dst) = func(ctx, kid); \
+    } else if (req) { \
+        throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
+    } \
+}
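+
+// usage sketch (illustrative; mirrors the call sites in llm_load_hparams and llm_load_vocab below):
+//
+//   uint32_t n_head_kv = hparams.n_head; // default kept if the key is missing
+//   GGUF_GET_KEY(ctx, n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+//
+// a type mismatch always throws std::runtime_error; with req == true a missing key throws as well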
+
 //
 // ggml helpers
 //
@@ -589,12 +827,13 @@ enum e_model {
     MODEL_7B,
     MODEL_13B,
     MODEL_30B,
+    MODEL_40B,
     MODEL_65B,
     MODEL_70B,
 };
 
 static const size_t kB = 1024;
-static const size_t MB = 1024*1024;
+static const size_t MB = kB*kB;
 
 // default hparams (LLaMA 7B)
 struct llama_hparams {
@@ -608,6 +847,7 @@ struct llama_hparams {
     uint32_t n_rot       = 64;
     uint32_t n_ff        = 11008;
 
+    float f_norm_eps     = 1e-5;
     float f_norm_rms_eps = 1e-5;
 
     float rope_freq_base  = 10000.0f;
@@ -641,21 +881,25 @@ struct llama_hparams {
 
 struct llama_layer {
     // normalization
-    struct ggml_tensor * attention_norm;
+    struct ggml_tensor * attn_norm;
+    struct ggml_tensor * attn_norm_b;
+    struct ggml_tensor * attn_norm_2;
+    struct ggml_tensor * attn_norm_2_b;
 
     // attention
     struct ggml_tensor * wq;
     struct ggml_tensor * wk;
     struct ggml_tensor * wv;
     struct ggml_tensor * wo;
+    struct ggml_tensor * wqkv;
 
     // normalization
     struct ggml_tensor * ffn_norm;
 
     // ff
-    struct ggml_tensor * w1;
-    struct ggml_tensor * w2;
-    struct ggml_tensor * w3;
+    struct ggml_tensor * w1; // ffn_gate
+    struct ggml_tensor * w2; // ffn_down
+    struct ggml_tensor * w3; // ffn_up
 };
 
 struct llama_kv_cache {
@@ -681,10 +925,6 @@ struct llama_kv_cache {
 };
 
 struct llama_vocab {
-    // TODO:
-    // - add a vector of merges
-    //   so that we can pass it to different types of tokenizers with a common interface
-
     using id    = int32_t;
     using token = std::string;
     using ttype = llama_token_type;
@@ -695,11 +935,13 @@ struct llama_vocab {
         ttype type;
     };
 
-    llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
 
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
+    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
+
     // default LLaMA special tokens
     id special_bos_id = 1;
     id special_eos_id = 2;
@@ -708,21 +950,40 @@ struct llama_vocab {
     id special_pad_id = -1;
 
     id linefeed_id = 13;
+
+    int find_bpe_rank(std::string token_left, std::string token_right) const {
+        replace_all(token_left,  " ",  "Ġ");
+        replace_all(token_left,  "\n", "Ċ");
+        replace_all(token_right, " ",  "Ġ");
+        replace_all(token_right, "\n", "Ċ");
+
+        auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
+        if (it == bpe_ranks.end()) {
+            return -1;
+        }
+
+        return it->second;
+    }
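+
+    // illustrative: merges are stored with GPT-2 byte-level markers, so a call like
+    // find_bpe_rank(" the", "\n") actually looks up the pair ("Ġthe", "Ċ"); unknown pairs return -1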
 };
 
 struct llama_model {
     e_model     type  = MODEL_UNKNOWN;
+    llm_arch    arch  = LLM_ARCH_UNKNOWN;
     llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
 
+    std::string name = "n/a";
+
     llama_hparams hparams;
     llama_vocab   vocab;
 
     struct ggml_tensor * tok_embeddings;
 
-    struct ggml_tensor * norm;
+    struct ggml_tensor * output_norm;
+    struct ggml_tensor * output_norm_b;
     struct ggml_tensor * output;
 
     std::vector<llama_layer> layers;
+
     int n_gpu_layers;
 
     // context
@@ -800,8 +1061,6 @@ struct llama_context {
     // key + value cache for the self attention
     struct llama_kv_cache kv_self;
 
-    size_t mem_per_token = 0;
-
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
@@ -880,11 +1139,11 @@ static bool llama_kv_cache_init(
 // model loading and saving
 //
 
-enum llama_file_version {
+enum llama_fver {
     GGUF_FILE_VERSION_V1 = 1,
 };
 
-static const char * llama_file_version_name(llama_file_version version) {
+static const char * llama_file_version_name(llama_fver version) {
     switch (version) {
         case GGUF_FILE_VERSION_V1: return "GGUF V1 (latest)";
     }
@@ -892,11 +1151,11 @@ static const char * llama_file_version_name(llama_file_version version) {
     return "unknown";
 }
 
-static std::string llama_format_tensor_shape(const std::vector<uint32_t> & ne) {
+static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
     char buf[256];
-    snprintf(buf, sizeof(buf), "%5u", ne.at(0));
+    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
     for (size_t i = 1; i < ne.size(); i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5u", ne.at(i));
+        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
     }
     return buf;
 }
@@ -919,9 +1178,9 @@ struct llama_model_loader {
 
     bool use_mmap = false;
 
-    llama_file file;
+    llama_file  file;
     llama_ftype ftype;
-    llama_file_version fver;
+    llama_fver  fver;
 
     std::unique_ptr<llama_mmap> mapping;
 
@@ -942,7 +1201,7 @@ struct llama_model_loader {
         n_kv      = gguf_get_n_kv(ctx_gguf);
         n_tensors = gguf_get_n_tensors(ctx_gguf);
 
-        fver = (enum llama_file_version) gguf_get_version(ctx_gguf);
+        fver = (enum llama_fver ) gguf_get_version(ctx_gguf);
 
         for (int i = 0; i < n_tensors; i++) {
             const char * name = gguf_get_tensor_name(ctx_gguf, i);
@@ -1039,6 +1298,21 @@ struct llama_model_loader {
         }
     }
 
+    std::string get_arch_name() const {
+        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+
+        std::string arch_name;
+        GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
+
+        return arch_name;
+    }
+
+    enum llm_arch get_arch() const {
+        const std::string arch_name = get_arch_name();
+
+        return llm_arch_from_string(arch_name);
+    }
+
     const char * get_tensor_name(int i) const {
         return gguf_get_tensor_name(ctx_gguf, i);
     }
@@ -1076,7 +1350,7 @@ struct llama_model_loader {
         return tensor;
     }
 
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
+    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, ggml_backend backend) {
         struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str());
 
         if (cur == NULL) {
@@ -1244,228 +1518,279 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_7B:  return "7B";
         case MODEL_13B: return "13B";
         case MODEL_30B: return "30B";
+        case MODEL_40B: return "40B";
         case MODEL_65B: return "65B";
         case MODEL_70B: return "70B";
-        default: GGML_ASSERT(false);
+        default:        return "?B";
     }
 }
 
-static void llama_model_load_internal(
-        const std::string & fname,
+static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
+    model.arch = ml.get_arch();
+    if (model.arch == LLM_ARCH_UNKNOWN) {
+        throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
+    }
+}
+
+static void llm_load_hparams(
+        llama_model_loader & ml,
         llama_model & model,
-        llama_vocab & vocab,
         int n_ctx,
-        int n_batch,
-        int n_gpu_layers,
-        int main_gpu,
-        const float * tensor_split,
-        const bool mul_mat_q,
         float rope_freq_base,
-        float rope_freq_scale,
-        bool low_vram,
-        ggml_type memory_type,
-        bool use_mmap,
-        bool use_mlock,
-        bool vocab_only,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
-    model.t_start_us = ggml_time_us();
+        float rope_freq_scale) {
+    struct gguf_context * ctx = ml.ctx_gguf;
 
-    std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
-
-    model.n_gpu_layers = n_gpu_layers;
+    const auto kv = LLM_KV(model.arch);
 
     auto & hparams = model.hparams;
 
-    std::string general_name = "n/a";
-    std::string general_arch = "n/a";
+    // get general kv
+    GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
 
-    // read hparams
+    // get hparams kv
+    GGUF_GET_KEY(ctx, hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, kv(LLM_KV_TOKENIZER_LIST));
+    GGUF_GET_KEY(ctx, hparams.n_ctx_train,    gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_CONTEXT_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_EMBEDDING_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_FEED_FORWARD_LENGTH));
+    GGUF_GET_KEY(ctx, hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
+    GGUF_GET_KEY(ctx, hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, kv(LLM_KV_BLOCK_COUNT));
+
+    // n_head_kv is optional, default to n_head
+    hparams.n_head_kv = hparams.n_head;
+    GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+
+    // TODO: manually setting rope scale should override this
+    // rope_freq_scale (inverse of the kv) is optional
     {
-        struct gguf_context * ctx = ml->ctx_gguf;
-
-#define GGUF_GET(dst, func, type, req, key) \
-        { \
-            const int kid = gguf_find_key(ctx, key); \
-            if (kid >= 0) { \
-                enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
-                if (ktype != (type)) { \
-                    throw std::runtime_error(format("key %s has wrong type: %s", key, gguf_type_name(ktype))); \
-                } \
-                (dst) = func(ctx, kid); \
-            } else if (req) { \
-                throw std::runtime_error(format("key not found in model: %s", key)); \
-            } \
+        float ropescale = 1.0f;
+        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+        if (ropescale != 1.0f) {
+            rope_freq_scale = 1.0f/ropescale;
         }
+    }
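+
+    // illustrative: a model exported with rope.scale_linear = 4.0 (4x linear context extension)
+    // yields rope_freq_scale = 0.25 here, i.e. the runtime value is the inverse of the stored KV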
 
-        std::string tokenizer_name;
-        GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model");
+    // sanity check for n_rot (optional)
+    {
+        hparams.n_rot = hparams.n_embd / hparams.n_head;
 
-        if (tokenizer_name == "llama") {
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
-        } else if (tokenizer_name == "gpt2") {
-            vocab.type = LLAMA_VOCAB_TYPE_BPE;
-        } else {
-            LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
-            LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
+        GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+
+        if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
+            throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
         }
+    }
 
-        // get hparams kv
-        GGUF_GET(hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, "tokenizer.ggml.tokens");
-        GGUF_GET(hparams.n_ctx_train,    gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.context_length");
-        GGUF_GET(hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.embedding_length");
-        GGUF_GET(hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.feed_forward_length");
-        GGUF_GET(hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.attention.head_count");
-        GGUF_GET(hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.block_count");
-        GGUF_GET(hparams.n_rot,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.rope.dimension_count");
-        GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon");
+    // arch-specific KVs
+    switch (model.arch) {
+        case LLM_ARCH_LLAMA:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+
+                switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_3B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    case 60: model.type = e_model::MODEL_30B; break;
+                    case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
 
-        // n_head_kv is optional, default to n_head
-        hparams.n_head_kv = hparams.n_head;
-        GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv");
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 60: model.type = e_model::MODEL_40B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        default: (void)0;
+    };
 
-        // TODO: manually setting rope scale should override this
-        // rope_freq_scale (inverse of the kv) is optional
-        float ropescale = 1.0f;
-        GGUF_GET(ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear");
-        if (ropescale != 1.0f) {
-            rope_freq_scale = 1.0f/ropescale;
-        }
+    model.ftype = ml.ftype;
 
-        // get general kv
-        GGUF_GET(general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
-        GGUF_GET(general_arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture");
+    hparams.n_ctx           = n_ctx;
+    hparams.rope_freq_base  = rope_freq_base;
+    hparams.rope_freq_scale = rope_freq_scale;
+}
 
-        // special tokens
-        GGUF_GET(vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id");
-        GGUF_GET(vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id");
-        GGUF_GET(vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id");
-        GGUF_GET(vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id");
-        GGUF_GET(vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id");
+// TODO: This should probably be in llama.h
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape);
 
-#undef GGUF_GET
+static void llm_load_vocab(
+        llama_model_loader & ml,
+        llama_model & model) {
+    auto & vocab = model.vocab;
 
-        switch (hparams.n_layer) {
-            case 26: model.type = e_model::MODEL_3B; break;
-            case 32: model.type = e_model::MODEL_7B; break;
-            case 40: model.type = e_model::MODEL_13B; break;
-            case 60: model.type = e_model::MODEL_30B; break;
-            case 80: model.type = e_model::MODEL_65B; break;
-            default:
-                {
-                    if (hparams.n_layer < 32) {
-                        model.type = e_model::MODEL_7B;
-                    }
-                } break;
-        }
+    struct gguf_context * ctx = ml.ctx_gguf;
 
-        model.ftype = ml->ftype;
+    const auto kv = LLM_KV(model.arch);
 
-        hparams.n_ctx = n_ctx;
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
 
-        // LLaMAv2
-        // TODO: probably not needed
-        {
-            const auto n_gqa = hparams.n_gqa();
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer scores in model file\n");
+    }
 
-            if (model.type == e_model::MODEL_65B && n_gqa == 8) {
-                LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
-                model.type = e_model::MODEL_70B;
-            }
-        }
+    const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
 
-        hparams.rope_freq_base  = rope_freq_base;
-        hparams.rope_freq_scale = rope_freq_scale;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx == -1) {
+        throw std::runtime_error("cannot find token type list in GGUF file\n");
     }
 
-    // read vocab
+    const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+
+    // determine vocab type
     {
-        struct gguf_context * ctx = ml->ctx_gguf;
+        std::string tokenizer_name;
 
-        vocab.id_to_token.resize(hparams.n_vocab);
+        GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
 
-        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
-        if (token_idx == -1) {
-            throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-        }
+        if (tokenizer_name == "llama") {
+            vocab.type = LLAMA_VOCAB_TYPE_SPM;
 
-        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
-        if (score_idx == -1) {
-            throw std::runtime_error("cannot find tokenizer scores in model file\n");
-        }
+            // default special tokens
+            vocab.special_bos_id = 1;
+            vocab.special_eos_id = 2;
+            vocab.special_unk_id = 0;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
+        } else if (tokenizer_name == "gpt2") {
+            vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
-        const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+            // read bpe merges and populate bpe ranks
+            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                throw std::runtime_error("cannot find tokenizer merges in model file\n");
+            }
 
-        const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
-        if (toktype_idx == -1) {
-            throw std::runtime_error("cannot find token type list in GGUF file\n");
-        }
+            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
 
-        const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+            for (int i = 0; i < n_merges; i++) {
+                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
 
-        for (uint32_t i = 0; i < hparams.n_vocab; i++) {
-            std::string word = gguf_get_arr_str(ctx, token_idx, i);
+                std::string first;
+                std::string second;
 
-            vocab.token_to_id[word] = i;
+                const size_t pos = word.find(' ', 1);
 
-            auto & token_data = vocab.id_to_token[i];
-            token_data.text  = std::move(word);
-            token_data.score = scores[i];
-            token_data.type  = (llama_token_type) toktypes[i];
+                if (pos != std::string::npos) {
+                    first  = word.substr(0, pos);
+                    second = word.substr(pos + 1);
+                }
 
-            // determine the newline token: 0x0A == 10 == '\n'
-            if (token_data.text == "<0x0A>") {
-                vocab.linefeed_id = i;
+                vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
             }
+
+            // default special tokens
+            vocab.special_bos_id = 11;
+            vocab.special_eos_id = 11;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
+        } else {
+            LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
+            LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
+
+            vocab.type = LLAMA_VOCAB_TYPE_SPM;
         }
     }
 
-    {
-        // hparams
-        LLAMA_LOG_INFO("%s: format       = %s\n",     __func__, llama_file_version_name(ml->fver));
-        LLAMA_LOG_INFO("%s: arch         = %s\n",     __func__, general_arch.c_str());
-        LLAMA_LOG_INFO("%s: vocab type   = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
-        LLAMA_LOG_INFO("%s: n_vocab      = %u\n",     __func__, hparams.n_vocab);
-        LLAMA_LOG_INFO("%s: n_ctx_train  = %u\n",     __func__, hparams.n_ctx_train);
-        LLAMA_LOG_INFO("%s: n_ctx        = %u\n",     __func__, hparams.n_ctx);
-        LLAMA_LOG_INFO("%s: n_embd       = %u\n",     __func__, hparams.n_embd);
-        LLAMA_LOG_INFO("%s: n_head       = %u\n",     __func__, hparams.n_head);
-        LLAMA_LOG_INFO("%s: n_head_kv    = %u\n",     __func__, hparams.n_head_kv);
-        LLAMA_LOG_INFO("%s: n_layer      = %u\n",     __func__, hparams.n_layer);
-        LLAMA_LOG_INFO("%s: n_rot        = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
-        LLAMA_LOG_INFO("%s: n_gqa        = %u\n",     __func__, hparams.n_gqa());
-        LLAMA_LOG_INFO("%s: f_norm_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-        LLAMA_LOG_INFO("%s: n_ff         = %u\n",     __func__, hparams.n_ff);
-        LLAMA_LOG_INFO("%s: freq_base    = %.1f\n",   __func__, hparams.rope_freq_base);
-        LLAMA_LOG_INFO("%s: freq_scale   = %g\n",     __func__, hparams.rope_freq_scale);
-        LLAMA_LOG_INFO("%s: model type   = %s\n",     __func__, llama_model_type_name(model.type));
-        LLAMA_LOG_INFO("%s: model ftype  = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-        LLAMA_LOG_INFO("%s: model size   = %.2f B\n", __func__, ml->n_elements*1e-9);
-
-        // general kv
-        LLAMA_LOG_INFO("%s: general.name = %s\n",    __func__, general_name.c_str());
-
-        // special tokens
-        if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
-        if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
-        if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
-        if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
-        if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
-        if (vocab.linefeed_id    != -1) { LLAMA_LOG_INFO( "%s: LF token  = %d '%s'\n", __func__, vocab.linefeed_id,    vocab.id_to_token[vocab.linefeed_id].text.c_str() );    }
-    }
-
-    if (vocab_only) {
-        LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
-        return;
+    const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+
+    vocab.id_to_token.resize(n_vocab);
+
+    for (uint32_t i = 0; i < n_vocab; i++) {
+        std::string word = gguf_get_arr_str(ctx, token_idx, i);
+
+        vocab.token_to_id[word] = i;
+
+        auto & token_data = vocab.id_to_token[i];
+        token_data.text  = std::move(word);
+        token_data.score = scores[i];
+        token_data.type  = (llama_token_type) toktypes[i];
     }
 
-    auto & ctx = model.ctx;
+    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
+    vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false, false)[0];
+
+    // special tokens
+    GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
+    GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
+    GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
+    GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
+    GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+}
+
+static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
+    const auto & hparams = model.hparams;
+    const auto & vocab   = model.vocab;
+
+    // hparams
+    LLAMA_LOG_INFO("%s: format         = %s\n",     __func__, llama_file_version_name(ml.fver));
+    LLAMA_LOG_INFO("%s: arch           = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+    LLAMA_LOG_INFO("%s: vocab type     = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
+    LLAMA_LOG_INFO("%s: n_vocab        = %u\n",     __func__, hparams.n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges       = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: n_ctx_train    = %u\n",     __func__, hparams.n_ctx_train);
+    LLAMA_LOG_INFO("%s: n_ctx          = %u\n",     __func__, hparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_embd         = %u\n",     __func__, hparams.n_embd);
+    LLAMA_LOG_INFO("%s: n_head         = %u\n",     __func__, hparams.n_head);
+    LLAMA_LOG_INFO("%s: n_head_kv      = %u\n",     __func__, hparams.n_head_kv);
+    LLAMA_LOG_INFO("%s: n_layer        = %u\n",     __func__, hparams.n_layer);
+    LLAMA_LOG_INFO("%s: n_rot          = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+    LLAMA_LOG_INFO("%s: n_gqa          = %u\n",     __func__, hparams.n_gqa());
+    LLAMA_LOG_INFO("%s: f_norm_eps     = %.1e\n",   __func__, hparams.f_norm_eps);
+    LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: n_ff           = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: freq_base      = %.1f\n",   __func__, hparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale     = %g\n",     __func__, hparams.rope_freq_scale);
+    LLAMA_LOG_INFO("%s: model type     = %s\n",     __func__, llama_model_type_name(model.type));
+    LLAMA_LOG_INFO("%s: model ftype    = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
+    LLAMA_LOG_INFO("%s: model size     = %.2f B\n", __func__, ml.n_elements*1e-9);
+
+    // general kv
+    LLAMA_LOG_INFO("%s: general.name   = %s\n",    __func__, model.name.c_str());
+
+    // special tokens
+    if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
+    if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
+    if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
+    if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
+    if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
+    if (vocab.linefeed_id    != -1) { LLAMA_LOG_INFO( "%s: LF token  = %d '%s'\n", __func__, vocab.linefeed_id,    vocab.id_to_token[vocab.linefeed_id].text.c_str() );    }
+}
+
+static void llm_load_tensors(
+        llama_model_loader & ml,
+        llama_model & model,
+        int n_batch,
+        int n_gpu_layers,
+        int main_gpu,
+        const float * tensor_split,
+        const bool mul_mat_q,
+        bool low_vram,
+        ggml_type memory_type,
+        bool use_mlock,
+        llama_progress_callback progress_callback,
+        void * progress_callback_user_data) {
+    model.t_start_us = ggml_time_us();
+
+    auto & ctx     = model.ctx;
+    auto & hparams = model.hparams;
+
+    model.n_gpu_layers = n_gpu_layers;
 
     size_t ctx_size;
     size_t mmapped_size;
 
-    ml->calc_sizes(ctx_size, mmapped_size);
+    ml.calc_sizes(ctx_size, mmapped_size);
 
     LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);
 
@@ -1480,7 +1805,7 @@ static void llama_model_load_internal(
         struct ggml_init_params params = {
             /*.mem_size   =*/ model.buf.size,
             /*.mem_buffer =*/ model.buf.data,
-            /*.no_alloc   =*/ ml->use_mmap,
+            /*.no_alloc   =*/ ml.use_mmap,
         };
 
         model.ctx = ggml_init(params);
@@ -1509,75 +1834,146 @@ static void llama_model_load_internal(
     // prepare memory for the weights
     size_t vram_weights = 0;
     {
-        const uint32_t n_embd     = hparams.n_embd;
-        const uint32_t n_embd_gqa = hparams.n_embd_gqa();
-        const uint32_t n_layer    = hparams.n_layer;
-        const uint32_t n_vocab    = hparams.n_vocab;
+        const int64_t n_embd     = hparams.n_embd;
+        const int64_t n_embd_gqa = hparams.n_embd_gqa();
+        const int64_t n_layer    = hparams.n_layer;
+        const int64_t n_vocab    = hparams.n_vocab;
 
-        model.tok_embeddings = ml->create_tensor(ctx, TN_TOKEN_EMBD, {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        const auto tn = LLM_TN(model.arch);
 
-        // "output" tensor
-        {
-            ggml_backend backend_norm;
-            ggml_backend backend_output;
-            if (n_gpu_layers > int(n_layer)) { // NOLINT
-                // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
-                // on Windows however this is detrimental unless everything is on the GPU
+        switch (model.arch) {
+            case LLM_ARCH_LLAMA:
+                {
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend backend_norm;
+                        ggml_backend backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #else
-                backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
 #endif // _WIN32
 
-                backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
-            } else {
-                backend_norm = GGML_BACKEND_CPU;
-                backend_output = GGML_BACKEND_CPU;
-            }
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
 
-            model.norm   = ml->create_tensor(ctx, TN_OUTPUT_NORM, {n_embd},          backend_norm);
-            model.output = ml->create_tensor(ctx, TN_OUTPUT,      {n_embd, n_vocab}, backend_output);
-            if (backend_norm == GGML_BACKEND_GPU) {
-                vram_weights += ggml_nbytes(model.norm);
-            }
-            if (backend_output == GGML_BACKEND_GPU_SPLIT) {
-                vram_weights += ggml_nbytes(model.output);
-            }
-        }
+                        model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
 
-        const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff = hparams.n_ff;
 
-        const int i_gpu_start = n_layer - n_gpu_layers;
+                    const int i_gpu_start = n_layer - n_gpu_layers;
 
-        model.layers.resize(n_layer);
-        for (uint32_t i = 0; i < n_layer; ++i) {
-            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-            const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                    model.layers.resize(n_layer);
 
-            auto & layer = model.layers[i];
-            layer.attention_norm = ml->create_tensor(ctx, format(TN_ATTN_NORM, i), {n_embd}, backend);
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
 
-            layer.wq = ml->create_tensor(ctx, format(TN_ATTN_Q, i),      {n_embd, n_embd},     backend_split);
-            layer.wk = ml->create_tensor(ctx, format(TN_ATTN_K, i),      {n_embd, n_embd_gqa}, backend_split);
-            layer.wv = ml->create_tensor(ctx, format(TN_ATTN_V, i),      {n_embd, n_embd_gqa}, backend_split);
-            layer.wo = ml->create_tensor(ctx, format(TN_ATTN_OUTPUT, i), {n_embd, n_embd},     backend_split);
+                        auto & layer = model.layers[i];
 
-            layer.ffn_norm = ml->create_tensor(ctx, format(TN_FFN_NORM, i), {n_embd}, backend);
+                        layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
 
-            layer.w1 = ml->create_tensor(ctx, format(TN_FFN_GATE, i), {n_embd,   n_ff}, backend_split);
-            layer.w2 = ml->create_tensor(ctx, format(TN_FFN_DOWN, i), {  n_ff, n_embd}, backend_split);
-            layer.w3 = ml->create_tensor(ctx, format(TN_FFN_UP, i),   {n_embd,   n_ff}, backend_split);
+                        layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd},     backend_split);
+                        layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
 
-            if (backend == GGML_BACKEND_GPU) {
-                vram_weights +=
-                    ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)             +
-                    ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
-                    ggml_nbytes(layer.w1)             + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
-            }
-        }
+                        layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)       +
+                                ggml_nbytes(layer.wv)        + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
+                                ggml_nbytes(layer.w1)        + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_FALCON:
+                {
+                    // TODO: CPU-only for now
+
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend backend_norm;
+                        ggml_backend backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, backend);
+                        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, backend);
+
+                        if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
+                            layer.attn_norm_2   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
+                            layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, backend);
+                        }
+
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},                backend_split);
+
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+                    }
+                } break;
+            default:
+                throw std::runtime_error("unknown architecture");
+        };
     }
 
-    ml->done_getting_tensors();
+    ml.done_getting_tensors();
 
     // print memory requirements
     {
@@ -1589,8 +1985,7 @@ static void llama_model_load_internal(
             mmapped_size - vram_weights; // weights in VRAM not in memory
 
         // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*hparams.kv_size();
+        const size_t mem_required_state = scale*hparams.kv_size();
 
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -1640,8 +2035,8 @@ static void llama_model_load_internal(
     }
 
     // populate `tensors_by_name`
-    for (int i = 0; i < ml->n_tensors; ++i) {
-        struct ggml_tensor * cur = ggml_get_tensor(ctx, ml->get_tensor_name(i));
+    for (int i = 0; i < ml.n_tensors; ++i) {
+        struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i));
         model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
     }
 
@@ -1652,13 +2047,13 @@ static void llama_model_load_internal(
     }
 #endif
 
-    ml->load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
+    ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
 
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);
     }
 
-    model.mapping = std::move(ml->mapping);
+    model.mapping = std::move(ml.mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1668,7 +2063,6 @@ static void llama_model_load_internal(
 static bool llama_model_load(
         const std::string & fname,
         llama_model & model,
-        llama_vocab & vocab,
         int n_ctx,
         int n_batch,
         int n_gpu_layers,
@@ -1685,17 +2079,36 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers,
-                                  main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
-                                  use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
-        return true;
+        std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
+
+        llm_load_arch   (*ml, model);
+        llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale);
+        llm_load_vocab  (*ml, model);
+
+        llm_load_print_meta(*ml, model);
+
+        if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
+            throw std::runtime_error("vocab size mismatch");
+        }
+
+        if (vocab_only) {
+            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
+            return true;
+        }
+
+        llm_load_tensors(
+                *ml, model, n_batch, n_gpu_layers,
+                main_gpu, tensor_split, mul_mat_q, low_vram, memory_type,
+                use_mlock, progress_callback, progress_callback_user_data);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
         return false;
     }
+
+    return true;
 }
 
-static struct ggml_cgraph * llama_build_graph(
+static struct ggml_cgraph * llm_build_llama(
          llama_context & lctx,
      const llama_token * tokens,
            const float * embd,
@@ -1729,8 +2142,7 @@ static struct ggml_cgraph * llama_build_graph(
 
     const int n_gpu_layers = model.n_gpu_layers;
 
-    auto & mem_per_token = lctx.mem_per_token;
-    auto & buf_compute   = lctx.buf_compute;
+    auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_compute.size,
@@ -1820,8 +2232,8 @@ static struct ggml_cgraph * llama_build_graph(
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
 
-            // cur = cur*attention_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
+            // cur = cur*attn_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
             offload_func(cur);
             ggml_set_name(cur, "attention_norm_0");
         }
@@ -1865,17 +2277,372 @@ static struct ggml_cgraph * llama_build_graph(
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
                 offload_func_v(v);
-                ggml_set_name(v, "v");
+                ggml_set_name(v, "v");
+
+                // important: storing RoPE-ed version of K in the KV cache!
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
+            ggml_set_name(Q, "Q");
+
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
+            ggml_set_name(K, "K");
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
+            ggml_set_name(KQ, "KQ");
+
+            // KQ_scaled = KQ / sqrt(n_embd_head)
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            offload_func_kq(KQ_masked);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            // split cached V into n_head heads
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            offload_func_v(V);
+            ggml_set_name(V, "V");
+
+#if 1
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
+            ggml_set_name(KQV, "KQV");
+#else
+            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+            // is there a better way?
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
+            ggml_set_name(KQV_merged, "KQV_merged");
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            offload_func_v(cur);
+            ggml_set_name(cur, "KQV_merged_contiguous");
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].wo,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_wo");
+        }
+
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+        offload_func(inpFF);
+        ggml_set_name(inpFF, "inpFF");
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+                offload_func(cur);
+                ggml_set_name(cur, "rms_norm_1");
+
+                // cur = cur*ffn_norm(broadcasted)
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+                offload_func(cur);
+                ggml_set_name(cur, "ffn_norm");
+            }
+
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model.layers[il].w3,
+                    cur);
+            offload_func(tmp);
+            ggml_set_name(tmp, "result_w3");
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w1,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_w1");
+
+            // SILU activation
+            cur = ggml_silu(ctx0, cur);
+            offload_func(cur);
+            ggml_set_name(cur, "silu");
+
+            cur = ggml_mul(ctx0, cur, tmp);
+            offload_func(cur);
+            ggml_set_name(cur, "silu_x_result_w3");
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w2,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_w2");
+        }
+
+        cur = ggml_add(ctx0, cur, inpFF);
+        offload_func(cur);
+        ggml_set_name(cur, "inpFF_+_result_w2");
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    // norm
+    {
+        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
+        offload_func_nr(cur);
+        ggml_set_name(cur, "rms_norm_2");
+
+        // cur = cur*norm(broadcasted)
+        cur = ggml_mul(ctx0, cur, model.output_norm);
+        // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
+        ggml_set_name(cur, "result_norm");
+    }
+
+    // lm_head
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    ggml_set_name(cur, "result_output");
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static struct ggml_cgraph * llm_build_falcon(
+         llama_context & lctx,
+     const llama_token * tokens,
+           const float * embd,
+                   int   n_tokens,
+                   int   n_past) {
+
+    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
+
+    const int N = n_tokens;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    GGML_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    const float freq_base  = hparams.rope_freq_base;
+    const float freq_scale = hparams.rope_freq_scale;
+    const float norm_eps   = hparams.f_norm_eps;
+
+    const int n_gpu_layers = model.n_gpu_layers;
+
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.data,
+        /*.no_alloc   =*/ false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    if (tokens) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+        }
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+
+        ggml_allocr_alloc(lctx.alloc, inpL);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+        }
+    }
+
+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    //
+    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
+    // in that case ggml_cuda_assign_buffers has no effect
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v  = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
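+
+    // illustrative: assuming Falcon-7B shapes (n_embd = 4544, n_head = 71), n_embd/n_head = 64
+    // and the attention scores are scaled by 1/sqrtf(64) = 0.125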
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * attn_norm;
+
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
+        // self-attention
+        // TODO: refactor into common function (shared with LLaMA)
+        {
+            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
+            offload_func(attn_norm);
+
+            attn_norm = ggml_add(ctx0,
+                    ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
+                    model.layers[il].attn_norm_b);
+            offload_func(attn_norm->src[0]);
+            offload_func(attn_norm);
+
+            if (model.layers[il].attn_norm_2) { // Falcon-40B
+                cur = ggml_norm(ctx0, inpL, norm_eps);
+                offload_func(cur);
+
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
+                        model.layers[il].attn_norm_2_b);
+                offload_func(cur->src[0]);
+                offload_func(cur);
+            } else { // Falcon 7B
+                cur = attn_norm;
+            }
+
+            // compute QKV
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            offload_func_kq(cur);
+
+            // Note that the strides for Kcur and Vcur are set up so that the
+            // resulting views are misaligned with the tensor's storage: by
+            // applying the K/V offset we shift the tensor's original view so
+            // that it nominally extends past the end of the viewed QKV
+            // tensor's allocated memory. This is ok because no actual
+            // accesses to that out-of-range memory happen, but it can require
+            // some trickery when trying to accurately dump these views for
+            // debugging. (See the offset sketch after this function.)
+
+            const size_t wsize = ggml_type_size(cur->type);
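+            // wsize is the size in bytes of one element of the QKV result;
+            // the view strides below are expressed in bytes, hence the wsize factors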
+
+            struct ggml_tensor * tmpq = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head, N,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                0);
+            offload_func_kq(tmpq);
+
+            struct ggml_tensor * tmpk = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, N,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head *  n_head);
+            offload_func_kq(tmpk);
+
+            struct ggml_tensor * tmpv = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, N,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head * (n_head +     n_head_kv));
+            offload_func_v(tmpv);
+
+            // RoPE with mode = 2 (NeoX-style rotation)
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            offload_func_kq(Qcur);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            offload_func_kq(Kcur);
+
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
+                offload_func_v(Vcur);
+                offload_func_v(Vcur->src[0]->src[0]);
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                offload_func_kq(k);
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                offload_func_v(v);
 
-                // important: storing RoPE-ed version of K in the KV cache!
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        Qcur,
-                        0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             offload_func_kq(Q);
             ggml_set_name(Q, "Q");
 
@@ -1888,28 +2655,22 @@ static struct ggml_cgraph * llama_build_graph(
             offload_func_kq(K);
             ggml_set_name(K, "K");
 
-            // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
-            // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
-            // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
             offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");
 
-            // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
             offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
-            // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
                         n_past + N, n_embd_head, n_head_kv,
@@ -1919,122 +2680,101 @@ static struct ggml_cgraph * llama_build_graph(
             offload_func_v(V);
             ggml_set_name(V, "V");
 
-#if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
             offload_func_v(KQV);
             ggml_set_name(KQV, "KQV");
-#else
-            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
-            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
-            // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
-#endif
 
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
 
-            // projection (no bias)
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].wo,
-                    cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
             offload_func(cur);
             ggml_set_name(cur, "result_wo");
         }
 
-        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-        offload_func(inpFF);
-        ggml_set_name(inpFF, "inpFF");
+        struct ggml_tensor * attn_out = cur;
 
-        // feed-forward network
+        // feed forward
         {
-            // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                offload_func(cur);
-                ggml_set_name(cur, "rms_norm_1");
-
-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                offload_func(cur);
-                ggml_set_name(cur, "ffn_norm");
-            }
-
-            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
-                    model.layers[il].w3,
-                    cur);
-            offload_func(tmp);
-            ggml_set_name(tmp, "result_w3");
+            struct ggml_tensor * inpFF = attn_norm;
 
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].w1,
-                    cur);
-            offload_func(cur);
-            ggml_set_name(cur, "result_w1");
+            cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
 
-            // SILU activation
-            cur = ggml_silu(ctx0, cur);
+            // TODO: this is temporarily needed to introduce an artificial dependency between FF and ATTN
+            //       added because there seems to be a bug in the Metal concurrency optimization;
+            //       without this line, the results are non-deterministic and wrong
+            cur->src[2] = attn_out;
             offload_func(cur);
-            ggml_set_name(cur, "silu");
 
-            cur = ggml_mul(ctx0, cur, tmp);
+            cur = ggml_gelu(ctx0, cur);
             offload_func(cur);
-            ggml_set_name(cur, "silu_x_result_w3");
-
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].w2,
-                    cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
             offload_func(cur);
-            ggml_set_name(cur, "result_w2");
         }
 
-        cur = ggml_add(ctx0, cur, inpFF);
+        cur = ggml_add(ctx0, cur, attn_out);
+        offload_func(cur);
+        cur = ggml_add(ctx0, cur, inpL);
         offload_func(cur);
-        ggml_set_name(cur, "inpFF_+_result_w2");
 
         // input for next layer
         inpL = cur;
     }
 
+    cur = inpL;
+
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
+        cur = ggml_norm(ctx0, cur, norm_eps);
         offload_func_nr(cur);
-        ggml_set_name(cur, "rms_norm_2");
 
-        // cur = cur*norm(broadcasted)
-        cur = ggml_mul(ctx0, cur, model.norm);
-        // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
+        cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.output_norm),
+                model.output_norm_b);
         ggml_set_name(cur, "result_norm");
     }
 
-    // lm_head
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
 
-    // logits -> probs
-    //cur = ggml_soft_max_inplace(ctx0, cur);
-
     ggml_build_forward_expand(gf, cur);
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-
     ggml_free(ctx0);
 
     return gf;
 }
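
For orientation, a small standalone sketch (not part of the commit) of where the Q/K/V view offsets in llm_build_falcon above come from: one row of the fused QKV projection holds n_head Q heads, then n_head_kv K heads, then n_head_kv V heads. The head counts below are made-up example values.

    // sketch: element offsets into one row of Falcon's fused QKV projection output
    #include <cstdio>

    int main() {
        const int n_head      = 8;   // example value only
        const int n_head_kv   = 2;   // example value only
        const int n_embd_head = 64;  // example value only

        const int row_elems = (n_head + 2*n_head_kv) * n_embd_head; // row stride used by the views above

        const int q_off = 0;                                   // Qcur view offset
        const int k_off =  n_head               * n_embd_head; // Kcur view offset
        const int v_off = (n_head + n_head_kv)  * n_embd_head; // Vcur view offset

        printf("row = %d  q = %d  k = %d  v = %d\n", row_elems, q_off, k_off, v_off);
        return 0;
    }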
 
+static struct ggml_cgraph * llama_build_graph(
+         llama_context & lctx,
+     const llama_token * tokens,
+           const float * embd,
+                   int   n_tokens,
+                   int   n_past) {
+    const auto & model = lctx.model;
+
+    struct ggml_cgraph * result = NULL;
+
+    switch (model.arch) {
+        case LLM_ARCH_LLAMA:
+            {
+                result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+                result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
+            } break;
+        default:
+            GGML_ASSERT(false);
+    };
+
+    return result;
+}
+
 // evaluate the transformer
 //
 //   - lctx:      llama context
@@ -2077,8 +2817,8 @@ static bool llama_eval_internal(
 
     GGML_ASSERT(!!kv_self.ctx);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_vocab     = hparams.n_vocab;
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = hparams.n_vocab;
 
     ggml_allocr_reset(lctx.alloc);
 
@@ -2108,11 +2848,11 @@ static bool llama_eval_internal(
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
 
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-    GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    GGML_ASSERT(strcmp(res->name,        "result_output") == 0);
+    GGML_ASSERT(strcmp(embeddings->name, "result_norm")   == 0);
 
 #if GGML_USE_MPI
     const int64_t n_layer = hparams.n_layer;
@@ -2271,13 +3011,7 @@ static std::string llama_unescape_whitespace(const std::string& word) {
     return word;
 }
 
-static size_t utf8_len(char src) {
-    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
-    return lookup[highbits];
-}
-
-struct llama_sp_symbol {
+struct llm_symbol {
     using index = int;
     index prev;
     index next;
@@ -2285,33 +3019,35 @@ struct llama_sp_symbol {
     size_t n;
 };
 
-static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable");
+static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
+
+// SPM tokenizer
+// original implementation:
+// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
 
-struct llama_sp_bigram {
+struct llm_bigram_spm {
     struct comparator {
-        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
+        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
             return (l.score < r.score) || (l.score == r.score && l.left > r.left);
         }
     };
-    using queue_storage = std::vector<llama_sp_bigram>;
-    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
-    llama_sp_symbol::index left;
-    llama_sp_symbol::index right;
+    using queue_storage = std::vector<llm_bigram_spm>;
+    using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
+    llm_symbol::index left;
+    llm_symbol::index right;
     float score;
     size_t size;
 };
 
-// original implementation:
-// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
-struct llama_tokenizer {
-    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
+struct llm_tokenizer_spm {
+    llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
         while (offs < text.size()) {
-            llama_sp_symbol sym;
+            llm_symbol sym;
             size_t len = utf8_len(text[offs]);
             GGML_ASSERT(offs + len <= text.size());
             sym.text = text.c_str() + offs;
@@ -2320,21 +3056,21 @@ struct llama_tokenizer {
             sym.prev = index - 1;
             sym.next = offs == text.size() ? -1 : index + 1;
             index++;
-            symbols_.emplace_back(sym);
+            symbols.emplace_back(sym);
         }
 
         // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols_.size(); ++i) {
+        for (size_t i = 1; i < symbols.size(); ++i) {
             try_add_bigram(i - 1, i);
         }
 
         // keep substituting the highest frequency pairs for as long as we can.
-        while (!work_queue_.empty()) {
-            auto bigram = work_queue_.top();
-            work_queue_.pop();
+        while (!work_queue.empty()) {
+            auto bigram = work_queue.top();
+            work_queue.pop();
 
-            auto & left_sym = symbols_[bigram.left];
-            auto & right_sym = symbols_[bigram.right];
+            auto & left_sym = symbols[bigram.left];
+            auto & right_sym = symbols[bigram.right];
 
             // if one of the symbols already got merged, skip it.
             if (left_sym.n == 0 || right_sym.n == 0 ||
@@ -2351,7 +3087,7 @@ struct llama_tokenizer {
             // remove the right sym from the chain
             left_sym.next = right_sym.next;
             if (right_sym.next >= 0) {
-                symbols_[right_sym.next].prev = bigram.left;
+                symbols[right_sym.next].prev = bigram.left;
             }
 
             // find more substitutions
@@ -2359,19 +3095,19 @@ struct llama_tokenizer {
             try_add_bigram(bigram.left, left_sym.next);
         }
 
-        for (int i = 0; i != -1; i = symbols_[i].next) {
-            auto & symbol = symbols_[i];
+        for (int i = 0; i != -1; i = symbols[i].next) {
+            auto & symbol = symbols[i];
             resegment(symbol, output);
         }
     }
 
 private:
-    void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
+    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab_.token_to_id.find(text);
+        auto token = vocab.token_to_id.find(text);
 
         // Do we need to support is_unused?
-        if (token != vocab_.token_to_id.end()) {
+        if (token != vocab.token_to_id.end()) {
             output.push_back((*token).second);
             return;
         }
@@ -2381,14 +3117,14 @@ private:
         if (p == rev_merge.end()) {
             // output any symbols that did not form tokens as bytes.
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]);
+                llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
                 output.push_back(token_id);
             }
             return;
         }
 
-        resegment(symbols_[p->second.first], output);
-        resegment(symbols_[p->second.second], output);
+        resegment(symbols[p->second.first],  output);
+        resegment(symbols[p->second.second], output);
     }
 
     void try_add_bigram(int left, int right) {
@@ -2396,56 +3132,261 @@ private:
             return;
         }
 
-        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
-        auto token = vocab_.token_to_id.find(text);
+        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
+        auto token = vocab.token_to_id.find(text);
 
-        if (token == vocab_.token_to_id.end()) {
+        if (token == vocab.token_to_id.end()) {
             return;
         }
 
-        if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
+        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
             return;
         }
 
-        const auto &tok_data = vocab_.id_to_token[(*token).second];
+        const auto & tok_data = vocab.id_to_token[(*token).second];
 
-        llama_sp_bigram bigram;
-        bigram.left = left;
+        llm_bigram_spm bigram;
+        bigram.left  = left;
         bigram.right = right;
         bigram.score = tok_data.score;
-        bigram.size = text.size();
-        work_queue_.push(bigram);
+        bigram.size  = text.size();
+
+        work_queue.push(bigram);
 
         // Do we need to support is_unused?
         rev_merge[text] = std::make_pair(left, right);
     }
 
-    const llama_vocab & vocab_;
-    std::vector<llama_sp_symbol> symbols_;
-    llama_sp_bigram::queue work_queue_;
-    std::map<std::string, std::pair<int, int> > rev_merge;
+    const llama_vocab & vocab;
+
+    std::vector<llm_symbol> symbols;
+    llm_bigram_spm::queue work_queue;
+
+    std::map<std::string, std::pair<int, int>> rev_merge;
+};
+
+// BPE tokenizer
+// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
+// tried to simplify unicode stuff, so most likely does not work 100% correctly!
+
+// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
+
+struct llm_bigram_bpe {
+    struct comparator {
+        bool operator()(llm_bigram_bpe & l, llm_bigram_bpe & r) {
+            return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
+        }
+    };
+
+    using queue_storage = std::vector<llm_bigram_bpe>;
+    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+    llm_symbol::index left;
+    llm_symbol::index right;
+    std::string text;
+    int rank;
+    size_t size;
+};
+
+struct llm_tokenizer_bpe {
+    llm_tokenizer_bpe(const llama_vocab & vocab, bool g2ws): vocab(vocab) { flag_g2ws = g2ws; }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        int final_prev_index = -1;
+        auto word_collection = bpe_gpt2_preprocess(text);
+
+        symbols_final.clear();
+
+        for (auto & word : word_collection) {
+            work_queue = llm_bigram_bpe::queue();
+            symbols.clear();
+
+            int index = 0;
+            size_t offset = 0;
+
+            while (offset < word.size()) {
+                llm_symbol sym;
+                size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
+                sym.text = word.c_str() + offset;
+                sym.n = char_len;
+                offset += sym.n;
+                sym.prev = index - 1;
+                sym.next = offset == word.size() ? -1 : index + 1;
+                index++;
+                symbols.emplace_back(sym);
+            }
+            for (size_t i = 1; i < symbols.size(); ++i) {
+                add_new_bigram(i - 1, i);
+            }
+
+            // build token(s)
+            while (!work_queue.empty()) {
+                auto bigram = work_queue.top();
+                work_queue.pop();
+
+                auto & left_symbol = symbols[bigram.left];
+                auto & right_symbol = symbols[bigram.right];
+
+                if (left_symbol.n == 0 || right_symbol.n == 0) {
+                    continue;
+                }
+                std::string left_token = std::string(left_symbol.text, left_symbol.n);
+                std::string right_token = std::string(right_symbol.text, right_symbol.n);
+                if (left_token + right_token != bigram.text) {
+                    continue;  // Skip this bigram if it's outdated
+                }
+
+                // merge the right sym into the left one
+                left_symbol.n += right_symbol.n;
+                right_symbol.n = 0;
+
+                // remove the right sym from the chain
+                left_symbol.next = right_symbol.next;
+                if (right_symbol.next >= 0) {
+                    symbols[right_symbol.next].prev = bigram.left;
+                }
+
+                add_new_bigram(left_symbol.prev, bigram.left);  // left side of current symbol
+                add_new_bigram(bigram.left, left_symbol.next);  // right side of current symbol
+            }
+
+            // add the finished tokens to the final list, keeping the correct order for next and prev
+            for (auto & sym : symbols) {
+                if (sym.n > 0) {
+                    sym.prev = final_prev_index;
+                    sym.next = -1;
+                    if (final_prev_index != -1) {
+                        symbols_final[final_prev_index].next = symbols_final.size();
+                    }
+                    symbols_final.emplace_back(sym);
+                    final_prev_index = symbols_final.size() - 1;
+                }
+            }
+        }
+
+        symbols = symbols_final;
+
+        if (!symbols.empty()) {
+            for (int i = 0; i != -1; i = symbols[i].next) {
+                auto & symbol = symbols[i];
+                if (symbol.n == 0) {
+                    continue;
+                }
+
+                const std::string str = std::string(symbol.text, symbol.n);
+                const auto token = vocab.token_to_id.find(str);
+
+                if (token == vocab.token_to_id.end()) {
+                    for (auto j = str.begin(); j != str.end(); ++j) {
+                        std::string byte_str(1, *j);
+                        auto token_multibyte = vocab.token_to_id.find(byte_str);
+                        if (token_multibyte == vocab.token_to_id.end()) {
+                            fprintf(stderr, "ERROR: byte not found in vocab: '%s'\n", byte_str.c_str());
+                            continue; // do not dereference the end iterator
+                        }
+                        output.push_back((*token_multibyte).second);
+                    }
+                } else {
+                    output.push_back((*token).second);
+                }
+            }
+        }
+    }
+
+private:
+    void add_new_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
+        std::string right_token = std::string(symbols[right].text, symbols[right].n);
+
+        int rank_found = -1;
+
+        rank_found = vocab.find_bpe_rank(left_token, right_token);
+
+        if (rank_found < 0) {
+            return;
+        }
+
+        llm_bigram_bpe bigram;
+
+        bigram.left  = left;
+        bigram.right = right;
+        bigram.text  = left_token + right_token;
+        bigram.size  = left_token.size() + right_token.size();
+        bigram.rank  = rank_found;
+
+        work_queue.push(bigram);
+    }
+
+    // probably not 100% correct
+    // TODO: this is quite slow - how to make it more efficient?
+    static std::vector<std::string> bpe_gpt2_preprocess(std::string text) {
+        std::vector<std::string> words;
+
+        // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+        const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
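+        // e.g. "Hello world!" is split into "Hello", " world", "!"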
+        const std::regex re(pattern);
+        std::smatch m;
+
+        while (std::regex_search(text, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            text = m.suffix();
+        }
+
+        return words;
+    }
+
+    bool flag_g2ws = false;
+
+    const llama_vocab & vocab;
+
+    std::vector<llm_symbol> symbols;
+    std::vector<llm_symbol> symbols_final;
+
+    llm_bigram_bpe::queue work_queue;
 };
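
To illustrate the merge loop above without llama.cpp types, here is a minimal sketch of rank-based greedy BPE merging over an invented merge table; it uses a plain linear scan instead of the work_queue, purely for brevity.

    // sketch: greedy lowest-rank BPE merging (hypothetical ranks, no llama.cpp types)
    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // lower rank = merged earlier, mirroring what vocab.find_bpe_rank() returns
        std::map<std::pair<std::string, std::string>, int> ranks = {
            {{"l", "o"}, 0}, {{"lo", "w"}, 1}, {{"e", "r"}, 2},
        };

        std::vector<std::string> sym = {"l", "o", "w", "e", "r"};

        while (true) {
            int best      = -1;
            int best_rank = 1 << 30;
            for (size_t i = 0; i + 1 < sym.size(); ++i) {
                auto it = ranks.find({sym[i], sym[i + 1]});
                if (it != ranks.end() && it->second < best_rank) {
                    best      = (int) i;
                    best_rank = it->second;
                }
            }
            if (best < 0) {
                break; // no known pair left to merge
            }
            sym[best] += sym[best + 1];        // merge the right symbol into the left one
            sym.erase(sym.begin() + best + 1); // and drop it from the chain
        }

        for (const auto & s : sym) {
            printf("'%s' ", s.c_str()); // prints: 'low' 'er'
        }
        printf("\n");
        return 0;
    }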
 
 static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
-    llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
 
     if (raw_text.empty()) {
         return output;
     }
 
-    if (bos) {
-        output.push_back(vocab.special_bos_id);
-    }
+    switch (vocab.type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                llm_tokenizer_spm tokenizer(vocab);
 
-    std::string text;
-    if (escape) {
-        text = llama_escape_whitespace(raw_text);
-    } else {
-        text = raw_text;
-    }
+                if (bos) {
+                    output.push_back(vocab.special_bos_id);
+                }
+
+                std::string text;
+                if (escape) {
+                    text = llama_escape_whitespace(raw_text);
+                } else {
+                    text = raw_text;
+                }
+
+                tokenizer.tokenize(text, output);
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe tokenizer(vocab, escape);
+
+                if (bos && vocab.special_bos_id != -1) {
+                    output.push_back(vocab.special_bos_id);
+                }
+
+                tokenizer.tokenize(raw_text, output);
+            } break;
+    };
 
-    tokenizer.tokenize(text, output);
     return output;
 }
 
@@ -3449,13 +4390,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         nthread = std::thread::hardware_concurrency();
     }
 
-    std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false));
+    std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
 
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     struct gguf_context * ctx_out = gguf_init_empty();
 
     // copy the KV pairs from the input file
-    gguf_set_kv     (ctx_out, model_loader->ctx_gguf);
+    gguf_set_kv     (ctx_out, ml->ctx_gguf);
     gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
     gguf_set_val_u32(ctx_out, "general.file_type", ftype);
 
@@ -3463,8 +4404,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     int n_attention_wv    = 0;
     int n_feed_forward_w2 = 0;
 
-    for (int i = 0; i < model_loader->n_tensors; ++i) {
-        struct ggml_tensor * meta = model_loader->get_tensor_meta(i);
+    for (int i = 0; i < ml->n_tensors; ++i) {
+        struct ggml_tensor * meta = ml->get_tensor_meta(i);
 
         const std::string name = ggml_get_name(meta);
 
@@ -3498,8 +4439,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::vector<uint8_t> work;
 
     // populate the original tensors so we get an initial meta data
-    for (int i = 0; i < model_loader->n_tensors; ++i) {
-        struct ggml_tensor * meta = model_loader->get_tensor_meta(i);
+    for (int i = 0; i < ml->n_tensors; ++i) {
+        struct ggml_tensor * meta = ml->get_tensor_meta(i);
         gguf_add_tensor(ctx_out, meta);
     }
 
@@ -3512,17 +4453,17 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
-    for (int i = 0; i < model_loader->n_tensors; ++i) {
-        struct ggml_tensor * tensor = model_loader->get_tensor_meta(i);
+    for (int i = 0; i < ml->n_tensors; ++i) {
+        struct ggml_tensor * tensor = ml->get_tensor_meta(i);
 
         const std::string name = ggml_get_name(tensor);
 
         read_data.resize(ggml_nbytes(tensor));
         tensor->data = read_data.data();
-        model_loader->load_data_for(tensor);
+        ml->load_data_for(tensor);
 
         LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
-               ++idx, model_loader->n_tensors,
+               ++idx, ml->n_tensors,
                ggml_get_name(tensor),
                llama_format_tensor_shape(tensor).c_str(),
                ggml_type_name(tensor->type));
@@ -3548,7 +4489,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_type = quantized_type;
 #ifdef GGML_USE_K_QUANTS
             // TODO: avoid hardcoded tensor names - use the TN_* constants
-            if (name == TN_OUTPUT) {
+            const auto tn = LLM_TN(ml->get_arch());
+
+            if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
                 int nx = tensor->ne[0];
                 int ny = tensor->ne[1];
                 if (nx % QK_K == 0 && ny % QK_K == 0) {
@@ -3600,10 +4543,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 }
             }
             if (convert_incompatible_tensor) {
-                if (name == TN_OUTPUT) {
+                if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
                     new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing.
                     LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n");
-                } else if (name == TN_TOKEN_EMBD) {
+                } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
                     new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing.
                     LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n");
                 } else {
@@ -3785,28 +4728,28 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
     }
 
     // load base model
-    std::unique_ptr<llama_model_loader> model_loader;
+    std::unique_ptr<llama_model_loader> ml;
     ggml_context * base_ctx = NULL;
     std::vector<uint8_t> base_buf;
     if (path_base_model) {
         LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
+        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
 
         size_t ctx_size;
         size_t mmapped_size;
-        model_loader->calc_sizes(ctx_size, mmapped_size);
+        ml->calc_sizes(ctx_size, mmapped_size);
         base_buf.resize(ctx_size);
 
         ggml_init_params base_params;
         base_params.mem_size   = base_buf.size();
         base_params.mem_buffer = base_buf.data();
-        base_params.no_alloc   = model_loader->use_mmap;
+        base_params.no_alloc   = ml->use_mmap;
 
         base_ctx = ggml_init(base_params);
 
         // maybe this should be in llama_model_loader
-        if (model_loader->use_mmap) {
-            model_loader->mapping.reset(new llama_mmap(&model_loader->file, /* prefetch */ 0, ggml_is_numa()));
+        if (ml->use_mmap) {
+            ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
         }
     }
 
@@ -3910,18 +4853,19 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
 #endif // GGML_USE_CUBLAS
 
             ggml_tensor * base_t;
-            if (model_loader) {
-                struct gguf_context * ctx_gguf = model_loader->ctx_gguf;
+            if (ml) {
+                struct gguf_context * ctx_gguf = ml->ctx_gguf;
 
                 // load from base model
                 if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
+                    // TODO: throw
                     LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
                     return 1;
                 }
 
                 // TODO: not tested!! maybe not working!
-                base_t = model_loader->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
-                model_loader->load_data_for(base_t);
+                base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
+                ml->load_data_for(base_t);
             } else {
                 base_t = dest_t;
             }
@@ -4096,7 +5040,23 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
+    unsigned cur_percentage = 0;
+    if (params.progress_callback == NULL) {
+        params.progress_callback_user_data = &cur_percentage;
+        params.progress_callback = [](float progress, void * ctx) {
+            unsigned * cur_percentage_p = (unsigned *) ctx;
+            unsigned percentage = (unsigned) (100 * progress);
+            while (percentage > *cur_percentage_p) {
+                *cur_percentage_p = percentage;
+                LLAMA_LOG_INFO(".");
+                if (percentage >= 100) {
+                    LLAMA_LOG_INFO("\n");
+                }
+            }
+        };
+    }
+
+    if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,
                 params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
                 params.progress_callback, params.progress_callback_user_data)) {
@@ -4126,22 +5086,6 @@ struct llama_context * llama_new_context_with_model(
         params.seed = time(NULL);
     }
 
-    unsigned cur_percentage = 0;
-    if (params.progress_callback == NULL) {
-        params.progress_callback_user_data = &cur_percentage;
-        params.progress_callback = [](float progress, void * ctx) {
-            unsigned * cur_percentage_p = (unsigned *) ctx;
-            unsigned percentage = (unsigned) (100 * progress);
-            while (percentage > *cur_percentage_p) {
-                *cur_percentage_p = percentage;
-                LLAMA_LOG_INFO(".");
-                if (percentage >= 100) {
-                    LLAMA_LOG_INFO("\n");
-                }
-            }
-        };
-    }
-
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
@@ -4279,13 +5223,14 @@ struct llama_context * llama_new_context_with_model(
 struct llama_context * llama_init_from_file(
                              const char * path_model,
             struct llama_context_params   params) {
-
     struct llama_model * model = llama_load_model_from_file(path_model, params);
     if (!model) {
         return nullptr;
     }
+
     struct llama_context * ctx = llama_new_context_with_model(model, params);
     ctx->model_owner = true;
+
     return ctx;
 }
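
A hedged usage sketch of this load path; the model path, error handling, and backend init/teardown are illustrative and not part of the commit.

    // sketch: load a model and create a context in one call (hypothetical model path)
    #include <cstdio>
    #include "llama.h"

    int main() {
        llama_backend_init(false /* numa */);

        llama_context_params params = llama_context_default_params();

        llama_context * ctx = llama_init_from_file("models/falcon-7b/ggml-model-q4_0.gguf", params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        llama_free(ctx); // also frees the model, since model_owner is set above

        llama_backend_free();
        return 0;
    }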
 
@@ -4305,6 +5250,10 @@ int llama_n_embd(const struct llama_context * ctx) {
     return ctx->model.hparams.n_embd;
 }
 
+enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
+    return ctx->model.vocab.type;
+}
+
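
Callers can use the new accessor to guard vocab-specific behavior; a minimal sketch (the helper name is made up, and it assumes "llama.h" is included):

    // sketch: report the vocabulary type of a loaded context
    static const char * vocab_type_name(const struct llama_context * ctx) {
        return llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE ? "BPE" : "SPM";
    }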
 int llama_model_n_vocab(const struct llama_model * model) {
     return model->vocab.id_to_token.size();
 }
@@ -4318,7 +5267,10 @@ int llama_model_n_embd(const struct llama_model * model) {
 }
 
 int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str());
+    return snprintf(buf, buf_size, "%s %s %s",
+            model->name.c_str(),
+            llama_model_type_name(model->type),
+            llama_model_ftype_name(model->ftype).c_str());
 }
 
 int llama_model_quantize(
@@ -4839,26 +5791,6 @@ int llama_tokenize(
     return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
 }
 
-int llama_tokenize_bpe(
-        struct llama_context * ctx,
-                  const char * text,
-                 llama_token * tokens,
-                         int   n_max_tokens,
-                        bool   add_bos) {
-    auto res = llama_tokenize_internal(ctx->model.vocab, text, add_bos, false);
-
-    if (n_max_tokens < (int) res.size()) {
-        LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
-}
-
 int llama_tokenize_with_model(
     const struct llama_model * model,
                   const char * text,
@@ -4884,18 +5816,6 @@ int llama_token_to_str(const struct llama_context * ctx, llama_token token, char
     return llama_token_to_str_with_model(&ctx->model, token, buf, length);
 }
 
-int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * buf, int length) {
-    if (0 <= token && token < llama_model_n_vocab(&ctx->model)) {
-        std::string result = ctx->model.vocab.id_to_token[token].text;
-        if (length < (int) result.length()) {
-            return -result.length();
-        }
-        memcpy(buf, result.c_str(), result.length());
-        return result.length();
-    }
-    return 0;
-}
-
 // does not write null-terminator to str
 int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_model_n_vocab(model)) {
diff --git a/llama.h b/llama.h
index 7ce478d5452a74cea0c898ad9bcc3fc20f7a5028..4e7638c042de970ac6eff9a15c6c22337fffee02 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -247,6 +247,8 @@ extern "C" {
     LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);
 
+    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
+
     LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
     LLAMA_API int llama_model_n_ctx  (const struct llama_model * model);
     LLAMA_API int llama_model_n_embd (const struct llama_model * model);
@@ -368,13 +370,6 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
-    LLAMA_API int llama_tokenize_bpe(
-            struct llama_context * ctx,
-                      const char * text,
-                     llama_token * tokens,
-                             int   n_max_tokens,
-                            bool   add_bos);
-
     LLAMA_API int llama_tokenize_with_model(
         const struct llama_model * model,
                       const char * text,
@@ -390,12 +385,6 @@ extern "C" {
                                   char * buf,
                                   int    length);
 
-    LLAMA_API int llama_token_to_str_bpe(
-            const struct llama_context * ctx,
-                           llama_token   token,
-                                  char * buf,
-                                  int    length);
-
     LLAMA_API int llama_token_to_str_with_model(
               const struct llama_model * model,
                            llama_token   token,
index 4ccefe9322322d6a56c7599cf52c45c07d6e6662..2afaf86b114504bb8d52d18dfd3d6f0d729c0af3 100644 (file)
@@ -28,7 +28,8 @@ llama_build_and_test_executable(test-sampling.cpp)
 llama_build_executable(test-tokenizer-0.cpp)
 llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
 llama_build_executable(test-tokenizer-1.cpp)
-llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+# test-tokenizer-1 requires a BPE vocab. re-enable when we have one.
+#llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
 llama_build_and_test_executable(test-grammar-parser.cpp)
 llama_build_and_test_executable(test-llama-grammar.cpp)
index 993d17f1833d30779671defd16deb3aa6b502fe9..bd607d12bb1cd8dd9293aa1959fc381749a2153b 100644 (file)
@@ -67,11 +67,13 @@ int main(int argc, char **argv) {
         }
     }
 
+    GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE);
+
     const int n_vocab = llama_n_vocab(ctx);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string forward = llama_token_to_str_bpe(ctx, i);
-        std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false);
+        std::string forward = llama_token_to_str(ctx, i);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, forward, false);
         if (tokens.size() == 1) {
             if (i != tokens[0]) {
                 std::string backward = llama_token_to_str(ctx, tokens[0]);
@@ -79,16 +81,6 @@ int main(int argc, char **argv) {
                     __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
                 return 2;
             }
-        } else {
-            llama_token_type type = llama_token_get_type(ctx, i);
-            if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) {
-                fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-            } else {
-                fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-                return 2;
-            }
         }
     }