]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 20 Feb 2024 10:09:57 +0000 (12:09 +0200)
committerGeorgi Gerganov <redacted>
Tue, 20 Feb 2024 10:09:57 +0000 (12:09 +0200)
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h
examples/talk-llama/talk-llama.cpp
examples/talk-llama/unicode.h

index a5b873a7bf144fb0b5657fafdabcd71323b10ac8..5de07dfa999a7cb8417ee0c3ddeef639428b81f8 100644 (file)
@@ -197,6 +197,7 @@ enum llm_arch {
     LLM_ARCH_PERSIMMON,
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
+    LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -211,27 +212,28 @@ enum llm_arch {
 };
 
 static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
-    { LLM_ARCH_LLAMA,           "llama"     },
-    { LLM_ARCH_FALCON,          "falcon"    },
-    { LLM_ARCH_GPT2,            "gpt2"      },
-    { LLM_ARCH_GPTJ,            "gptj"      },
-    { LLM_ARCH_GPTNEOX,         "gptneox"   },
-    { LLM_ARCH_MPT,             "mpt"       },
-    { LLM_ARCH_BAICHUAN,        "baichuan"  },
-    { LLM_ARCH_STARCODER,       "starcoder" },
-    { LLM_ARCH_PERSIMMON,       "persimmon" },
-    { LLM_ARCH_REFACT,          "refact"    },
-    { LLM_ARCH_BERT,            "bert"      },
-    { LLM_ARCH_BLOOM,           "bloom"     },
-    { LLM_ARCH_STABLELM,        "stablelm"  },
-    { LLM_ARCH_QWEN,            "qwen"      },
-    { LLM_ARCH_QWEN2,           "qwen2"     },
-    { LLM_ARCH_PHI2,            "phi2"      },
-    { LLM_ARCH_PLAMO,           "plamo"     },
-    { LLM_ARCH_CODESHELL,       "codeshell" },
-    { LLM_ARCH_ORION,           "orion"     },
-    { LLM_ARCH_INTERNLM2,       "internlm2" },
-    { LLM_ARCH_MINICPM,         "minicpm"   },
+    { LLM_ARCH_LLAMA,           "llama"      },
+    { LLM_ARCH_FALCON,          "falcon"     },
+    { LLM_ARCH_GPT2,            "gpt2"       },
+    { LLM_ARCH_GPTJ,            "gptj"       },
+    { LLM_ARCH_GPTNEOX,         "gptneox"    },
+    { LLM_ARCH_MPT,             "mpt"        },
+    { LLM_ARCH_BAICHUAN,        "baichuan"   },
+    { LLM_ARCH_STARCODER,       "starcoder"  },
+    { LLM_ARCH_PERSIMMON,       "persimmon"  },
+    { LLM_ARCH_REFACT,          "refact"     },
+    { LLM_ARCH_BERT,            "bert"       },
+    { LLM_ARCH_NOMIC_BERT,      "nomic-bert" },
+    { LLM_ARCH_BLOOM,           "bloom"      },
+    { LLM_ARCH_STABLELM,        "stablelm"   },
+    { LLM_ARCH_QWEN,            "qwen"       },
+    { LLM_ARCH_QWEN2,           "qwen2"      },
+    { LLM_ARCH_PHI2,            "phi2"       },
+    { LLM_ARCH_PLAMO,           "plamo"      },
+    { LLM_ARCH_CODESHELL,       "codeshell"  },
+    { LLM_ARCH_ORION,           "orion"      },
+    { LLM_ARCH_INTERNLM2,       "internlm2"  },
+    { LLM_ARCH_MINICPM,         "minicpm"    },
 };
 
 enum llm_kv {
@@ -254,6 +256,7 @@ enum llm_kv {
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
+    LLM_KV_POOLING_TYPE,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -311,6 +314,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TENSOR_DATA_LAYOUT,            "%s.tensor_data_layout"    },
     { LLM_KV_EXPERT_COUNT,                  "%s.expert_count"          },
     { LLM_KV_EXPERT_USED_COUNT,             "%s.expert_used_count"     },
+    { LLM_KV_POOLING_TYPE ,                 "%s.pooling_type"          },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
@@ -373,6 +377,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_OUT,
     LLM_TENSOR_ATTN_NORM,
     LLM_TENSOR_ATTN_NORM_2,
+    LLM_TENSOR_ATTN_OUT_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
     LLM_TENSOR_FFN_GATE_INP,
     LLM_TENSOR_FFN_NORM,
@@ -385,6 +390,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_UP_EXP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
+    LLM_TENSOR_LAYER_OUT_NORM,
 };
 
 static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -550,12 +556,27 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
             { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
             { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
             { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_NOMIC_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
@@ -772,22 +793,37 @@ struct LLM_TN {
     llm_arch arch;
 
     std::string operator()(llm_tensor tensor) const {
+        if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) {
+            return "__missing__";
+        }
         return LLM_TENSOR_NAMES[arch].at(tensor);
     }
 
     std::string operator()(llm_tensor tensor, const std::string & suffix) const {
+        if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) {
+            return "__missing__";
+        }
         return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix;
     }
 
     std::string operator()(llm_tensor tensor, int bid) const {
+        if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) {
+            return "__missing__";
+        }
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid);
     }
 
     std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
+        if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) {
+            return "__missing__";
+        }
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
 
     std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
+        if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) {
+            return "__missing__";
+        }
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix;
     }
 };
@@ -998,7 +1034,7 @@ struct llama_mmap {
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
         // prefetch/readahead impairs performance on NUMA systems
-        if (numa) { prefetch = 0; }
+        if (numa)  { prefetch = 0; }
 #ifdef __linux__
         // advise the kernel to read the file sequentially (increases readahead)
         if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
@@ -1468,6 +1504,7 @@ enum e_model {
     MODEL_22M,
     MODEL_33M,
     MODEL_109M,
+    MODEL_137M,
     MODEL_335M,
     MODEL_0_5B,
     MODEL_1B,
@@ -1520,11 +1557,13 @@ struct llama_hparams {
     uint32_t n_yarn_orig_ctx;
     int32_t  rope_scaling_type_train;
 
-    float f_clamp_kqv;
-    float f_max_alibi_bias;
+    float f_clamp_kqv      = 0.0f;
+    float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
+    bool need_kq_pos = false;
 
+    uint32_t pooling_type = LLAMA_POOLING_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -1586,6 +1625,7 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+    bool do_pooling;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -1601,6 +1641,8 @@ struct llama_layer {
     struct ggml_tensor * attn_q_norm_b;
     struct ggml_tensor * attn_k_norm;
     struct ggml_tensor * attn_k_norm_b;
+    struct ggml_tensor * attn_out_norm;
+    struct ggml_tensor * attn_out_norm_b;
 
     // attention
     struct ggml_tensor * wq;
@@ -1619,6 +1661,8 @@ struct llama_layer {
     // normalization
     struct ggml_tensor * ffn_norm;
     struct ggml_tensor * ffn_norm_b;
+    struct ggml_tensor * layer_out_norm;
+    struct ggml_tensor * layer_out_norm_b;
 
     // ff
     struct ggml_tensor * ffn_gate; // w1
@@ -1880,8 +1924,10 @@ struct llama_context {
     struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
     struct ggml_tensor * inp_KQ_mask;   // F32 [n_ctx, n_batch]
+    struct ggml_tensor * inp_KQ_pos;    // F32 [n_ctx]
     struct ggml_tensor * inp_K_shift;   // I32 [n_ctx]
-    struct ggml_tensor * inp_sum;       // F32 [1, n_batch]
+    struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;       // I32 [n_batch]
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -2480,6 +2526,7 @@ struct llama_model_loader {
                 case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
                 case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
                 case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
+                case GGML_TYPE_IQ1_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_S;   break;
                 default:
                     {
                         LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@@ -2829,6 +2876,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
         case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small";
         case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
+        case LLAMA_FTYPE_MOSTLY_IQ1_S  :return "IQ1_S - 1.5625 bpw";
 
         default: return "unknown, may not work";
     }
@@ -2836,6 +2884,11 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
+        case MODEL_22M:    return "22M";
+        case MODEL_33M:    return "33M";
+        case MODEL_109M:   return "109M";
+        case MODEL_137M:   return "137M";
+        case MODEL_0_5B:   return "0.5B";
         case MODEL_1B:     return "1B";
         case MODEL_2B:     return "2B";
         case MODEL_3B:     return "3B";
@@ -3005,6 +3058,11 @@ static void llm_load_hparams(
                     case 40: model.type = e_model::MODEL_13B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
+
+                if (model.type == e_model::MODEL_13B) {
+                    // TODO: become GGUF KV parameter
+                    hparams.f_max_alibi_bias = 8.0f;
+                }
             } break;
         case LLM_ARCH_STARCODER:
             {
@@ -3032,12 +3090,16 @@ static void llm_load_hparams(
                     case 32: model.type = e_model::MODEL_1B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
+
+                // TODO: become GGUF KV parameter
+                hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_BERT:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -3053,6 +3115,17 @@ static void llm_load_hparams(
                         model.type = e_model::MODEL_335M; break; // bge-large
                 }
             } break;
+        case LLM_ARCH_NOMIC_BERT:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+                if (hparams.n_layer == 12 && hparams.n_embd == 768) {
+                    model.type = e_model::MODEL_137M;
+                }
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3065,11 +3138,12 @@ static void llm_load_hparams(
                             case 4096: model.type = e_model::MODEL_7B; break;
                         } break;
                 }
+
+                // TODO: become GGUF KV parameter
+                hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_MPT:
             {
-                hparams.f_clamp_kqv = 0.0f;
-
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
                 ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
@@ -3171,6 +3245,10 @@ static void llm_load_hparams(
     }
 
     model.ftype = ml.ftype;
+
+    if (hparams.f_max_alibi_bias > 0.0f) {
+        hparams.need_kq_pos = true;
+    }
 }
 
 // TODO: This should probably be in llama.h
@@ -3294,7 +3372,12 @@ static void llm_load_vocab(
 
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        try {
+            vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
+            vocab.linefeed_id = vocab.special_pad_id;
+        }
     } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
         vocab.linefeed_id = vocab.special_pad_id;
     } else {
@@ -3850,10 +3933,14 @@ static bool llm_load_tensors(
                     }
                 } break;
             case LLM_ARCH_BERT:
+            case LLM_ARCH_NOMIC_BERT:
                 {
-                    model.tok_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_embd, n_vocab_type});
-                    model.pos_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,        "weight"), {n_embd, hparams.n_ctx_train});
+                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
+                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
+                    if (model.arch == LLM_ARCH_BERT) {
+                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, hparams.n_ctx_train});
+                    }
+
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
                     model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
 
@@ -3863,29 +3950,38 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        if (model.arch == LLM_ARCH_BERT) {
+                            layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                            layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
 
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                            layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+                            layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                        } else {
+                            layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        }
 
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
+                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
 
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
+                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.ffn_up          = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
+                        layer.ffn_down        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        if (model.arch == LLM_ARCH_BERT) {
+                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
 
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        } else {
+                            layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        }
+
+                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
                     }
                 } break;
             case LLM_ARCH_BLOOM:
@@ -4364,9 +4460,21 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
 
         model.hparams.vocab_only = params.vocab_only;
 
-        llm_load_arch   (ml, model);
-        llm_load_hparams(ml, model);
-        llm_load_vocab  (ml, model);
+        try {
+            llm_load_arch(ml, model);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
+        }
+        try {
+            llm_load_hparams(ml, model);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
+        }
+        try {
+            llm_load_vocab(ml, model);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
+        }
 
         llm_load_print_meta(ml, model);
 
@@ -4683,10 +4791,10 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
+         struct ggml_tensor * kq_pos,
                     int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   n_kv,
-                    float     max_alibi_bias,
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
@@ -4716,26 +4824,26 @@ static struct ggml_tensor * llm_build_kqv(
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }
 
-    if (max_alibi_bias > 0.0f) {
-        // temporary branch until we figure out how to handle ggml_alibi through ggml_add
+#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
+#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
+#pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
+    if (hparams.f_max_alibi_bias > 0.0f) {
         kq = ggml_scale(ctx, kq, kq_scale);
         cb(kq, "kq_scaled", il);
 
-        if (max_alibi_bias > 0.0f) {
-            // TODO: n_head or n_head_kv
-            // TODO: K-shift is likely not working
-            // TODO: change to ggml_add
-            kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
-            cb(kq, "kq_scaled_alibi", il);
-        }
+        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
+        cb(kq, "kq_scaled_alibi", il);
 
         kq = ggml_add(ctx, kq, kq_mask);
         cb(kq, "kq_masked", il);
 
         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
-    } else {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+    } else
+#endif
+    {
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
     }
 
@@ -4783,11 +4891,11 @@ static struct ggml_tensor * llm_build_kv(
          struct ggml_tensor * v_cur,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
+         struct ggml_tensor * kq_pos,
                     int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   kv_head,
                     int32_t   n_kv,
-                    float     max_alibi_bias,
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
@@ -4801,9 +4909,8 @@ static struct ggml_tensor * llm_build_kv(
     llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
-    cur  = llm_build_kqv(ctx, model, hparams, kv, graph,
-            wo, wo_b,
-            q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
+    cur  = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
+            q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -4844,7 +4951,7 @@ struct llm_build_context {
     const int32_t n_orig_ctx;
 
     const bool do_rope_shift;
-    const bool causal_attn;
+    const uint32_t pooling_type;
 
     const llm_build_cb & cb;
 
@@ -4888,7 +4995,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
-        causal_attn      (hparams.causal_attn),
+        pooling_type     (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_NONE),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -4971,7 +5078,7 @@ struct llm_build_context {
                 }
 
                 Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
                     hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -4986,7 +5093,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5116,6 +5223,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         // shift the entire K-cache if needed
         if (do_rope_shift) {
             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
@@ -5164,12 +5275,9 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
 
 
-                // apply ALiBi for 13B model
-                const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
-
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5293,7 +5401,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5392,7 +5500,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5597,7 +5705,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5659,6 +5767,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -5686,7 +5798,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5736,22 +5848,27 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // get input vectors with right size
+        const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
         struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+        struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
+        struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
         inpL = ggml_add(ctx0, inpL, type_row0);
-        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        if (model.arch == LLM_ARCH_BERT) {
+            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        }
         cb(inpL, "inp_embd", -1);
 
         // embed layer norm
@@ -5767,7 +5884,7 @@ struct llm_build_context {
             struct ggml_tensor * cur = inpL;
 
             // self-attention
-            {
+            if (model.arch == LLM_ARCH_BERT) {
                 struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
@@ -5782,7 +5899,38 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            } else {
+                // compute Q and K and RoPE them
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5790,25 +5938,34 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, inpL);
 
             // attention layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il);
 
             struct ggml_tensor * ffn_inp = cur;
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b,
-                    NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            if (model.arch == LLM_ARCH_BERT) {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            } else {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            }
             cb(cur, "ffn_out", il);
 
             // attentions bypass the intermediate layer
             cur = ggml_add(ctx0, cur, ffn_inp);
 
             // output layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, cb, il);
 
             // input for next layer
             inpL = cur;
@@ -5817,9 +5974,15 @@ struct llm_build_context {
         // final output
         cur = inpL;
 
-        // pooling
-        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
-        cb(cur, "result_embed", -1);
+        // pooling layer
+        if (pooling_type == LLAMA_POOLING_MEAN) {
+            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+        } else if (pooling_type == LLAMA_POOLING_CLS) {
+            cur = ggml_get_rows(ctx0, cur, inp_cls);
+        } else {
+            GGML_ASSERT(pooling_type == LLAMA_POOLING_NONE && "Invalid pooling type");
+        }
+        cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
 
@@ -5843,6 +6006,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         inpL = llm_build_norm(ctx0, inpL, hparams,
                 model.tok_norm,
                 model.tok_norm_b,
@@ -5876,7 +6043,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5936,6 +6103,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * attn_norm;
 
@@ -5969,7 +6140,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6091,7 +6262,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6206,7 +6377,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6327,7 +6498,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6454,7 +6625,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6557,7 +6728,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
             struct ggml_tensor * sa_out = cur;
@@ -6656,7 +6827,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6765,7 +6936,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6883,7 +7054,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7002,7 +7173,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7134,7 +7305,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7249,6 +7420,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_refact();
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
             {
                 result = llm.build_bert();
             } break;
@@ -7352,7 +7524,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
+                        (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                         f = -INFINITY;
                     } else {
                         f = 0;
@@ -7363,13 +7536,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (hparams.need_kq_pos) {
+        const int64_t n_kv = kv_self.n;
 
-    {
-        assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
-        float * data = (float *) lctx.inp_sum->data;
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+
+        float * data = (float *) lctx.inp_KQ_pos->data;
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            data[i] = 1.0f/float(batch.n_tokens);
+        for (int i = 0; i < n_kv; ++i) {
+            data[i] = float(lctx.kv_self.cells[i].pos);
         }
     }
 
@@ -7384,6 +7559,49 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             data[i] = lctx.kv_self.cells[i].delta;
         }
     }
+
+    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_MEAN) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+        float * data = (float *) lctx.inp_mean->data;
+
+        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+
+        std::vector<uint64_t> sum(n_tokens, 0);
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            sum[seq_id] += 1;
+        }
+
+        std::vector<float> div(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const uint64_t s = sum[i];
+            if (s > 0) {
+                div[i] = 1.0f/float(s);
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            data[seq_id*n_tokens + i] = div[seq_id];
+        }
+    }
+
+    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_CLS) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_pos pos = batch.pos[i];
+            if (pos == 0) {
+                data[seq_id] = i;
+            }
+        }
+    }
 }
 
 // decode a batch of tokens by evaluating the transformer
@@ -7495,7 +7713,7 @@ static int llama_decode_internal(
             embeddings = gf->nodes[gf->n_nodes - 3];
             GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
         }
-    } else if (strcmp(res->name, "result_embed") == 0) {
+    } else if (strcmp(res->name, "result_embd") == 0) {
         embeddings = res;
         res = nullptr;
     } else {
@@ -7615,11 +7833,12 @@ static int llama_decode_internal(
     if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
 
-        const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
+        const int64_t embd_pos  = res ? n_embd * (n_tokens-1) : 0;
+        const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
 
-        embedding_out.resize(n_embd);
+        embedding_out.resize(embd_size);
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
 
@@ -7696,7 +7915,13 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
     switch (llama_vocab_get_type(vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
             const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            return vocab.token_to_id.at(buf);
+            auto token = vocab.token_to_id.find(buf);
+            if (token != vocab.token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return vocab.token_to_id.at(buf2);
         }
         case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_BPE: {
@@ -7744,7 +7969,7 @@ struct llm_bigram_spm {
 };
 
 struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
+    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
@@ -7819,6 +8044,7 @@ private:
 
         if (p == rev_merge.end()) {
             // output any symbols that did not form tokens as bytes.
+            output.reserve(output.size() + symbol.n);
             for (int j = 0; j < (int)symbol.n; ++j) {
                 llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
                 output.push_back(token_id);
@@ -8381,17 +8607,18 @@ struct fragment_buffer_variant {
         token(_token),
         raw_text(_dummy),
         offset(0),
-        length(0){}
+        length(0) {}
+
     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id)-1),
+        token((llama_vocab::id) - 1),
         raw_text(_raw_text),
         offset(_offset),
         length(_length){
-            GGML_ASSERT( _offset >= 0 );
-            GGML_ASSERT( _length >= 1 );
-            GGML_ASSERT( offset + length <= raw_text.length() );
+            GGML_ASSERT(_offset >= 0);
+            GGML_ASSERT(_length >= 1);
+            GGML_ASSERT(offset + length <= raw_text.length());
         }
 
     const FRAGMENT_BUFFER_VARIANT_TYPE type;
@@ -8515,14 +8742,14 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
     }
 
     std::forward_list<fragment_buffer_variant> fragment_buffer;
-    fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
+    fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
 
-    if (special) tokenizer_st_partition( vocab, fragment_buffer );
+    if (special) tokenizer_st_partition(vocab, fragment_buffer);
 
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
-                for (const auto & fragment: fragment_buffer) {
+                for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         // without adding this leading whitespace, we do not get the same results as the original tokenizer
 
@@ -8550,7 +8777,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                for (const auto & fragment: fragment_buffer) {
+                for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
@@ -8566,7 +8793,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
             {
-                for (const auto & fragment: fragment_buffer) {
+                for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
@@ -10087,20 +10314,20 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
         if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
             new_type = GGML_TYPE_Q8_0;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
             new_type = GGML_TYPE_Q5_K;
         }
         else if (new_type != GGML_TYPE_Q8_0) {
             new_type = GGML_TYPE_Q6_K;
         }
     } else if (name == "token_embd.weight") {
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
             new_type = GGML_TYPE_Q2_K;
         }
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
             new_type = GGML_TYPE_Q4_K;
         }
-    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
+    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
         if (name.find("attn_v.weight") != std::string::npos) {
             if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
             else new_type = GGML_TYPE_Q2_K;
@@ -10110,6 +10337,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
             if (qs.i_ffn_down < qs.n_ffn_down/8) new_type = GGML_TYPE_Q2_K;
             ++qs.i_ffn_down;
         }
+        else if (name.find("attn_output.weight") != std::string::npos) {
+            if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS;
+        }
     } else if (name.find("attn_v.weight") != std::string::npos) {
         if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
             new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
@@ -10227,6 +10457,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
         }
         ++qs.i_ffn_up;
     }
+
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
     //}
     // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
@@ -10242,7 +10473,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
     if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
         new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K ||
         new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS ||
-        new_type == GGML_TYPE_IQ3_XXS) {
+        new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
         if (nx % QK_K != 0) {
@@ -10257,6 +10488,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
             case GGML_TYPE_IQ2_XXS:
             case GGML_TYPE_IQ2_XS:
             case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ1_S:
             case GGML_TYPE_Q2_K: new_type = GGML_TYPE_Q4_0; break;
             case GGML_TYPE_Q3_K: new_type = GGML_TYPE_Q4_1; break;
             case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
@@ -10286,19 +10518,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q2_K:   quantized_type = GGML_TYPE_Q2_K; break;
+        case LLAMA_FTYPE_MOSTLY_Q2_K:    quantized_type = GGML_TYPE_Q2_K;    break;
         case LLAMA_FTYPE_MOSTLY_Q3_K_XS:
         case LLAMA_FTYPE_MOSTLY_Q3_K_S:
         case LLAMA_FTYPE_MOSTLY_Q3_K_M:
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break;
+        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  quantized_type = GGML_TYPE_Q3_K;    break;
         case LLAMA_FTYPE_MOSTLY_Q4_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  quantized_type = GGML_TYPE_Q4_K;    break;
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
-        case LLAMA_FTYPE_MOSTLY_Q6_K:   quantized_type = GGML_TYPE_Q6_K; break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:quantized_type = GGML_TYPE_IQ2_XXS; break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS :quantized_type = GGML_TYPE_IQ2_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS:quantized_type = GGML_TYPE_IQ3_XXS; break;
+        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  quantized_type = GGML_TYPE_Q5_K;    break;
+        case LLAMA_FTYPE_MOSTLY_Q6_K:    quantized_type = GGML_TYPE_Q6_K;    break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break;
+        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  quantized_type = GGML_TYPE_IQ2_XS;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
+        case LLAMA_FTYPE_MOSTLY_IQ1_S:   quantized_type = GGML_TYPE_IQ1_S  ; break;
 
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
@@ -10428,7 +10661,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= !params->only_copy;
 
         // do not quantize expert gating tensors
-        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
+        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_FFN_GATE_INP, "weight");
+
+        // do not quantize positional embeddings and token types (BERT)
+        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
+        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
 
         enum ggml_type new_type;
         void * new_data;
@@ -10468,6 +10705,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             }
             if ((new_type == GGML_TYPE_IQ2_XXS ||
                  new_type == GGML_TYPE_IQ2_XS  ||
+                 new_type == GGML_TYPE_IQ1_S   ||
                 (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
                 LLAMA_LOG_ERROR("\n\n============================================================\n");
                 LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
@@ -10702,7 +10940,7 @@ static int llama_apply_lora_from_file_internal(
                     {
                         LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
                                 __func__, ftype);
-                        return false;
+                        return 1;
                     }
         }
 
@@ -10930,6 +11168,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
+        /*.do_pooling                  =*/ true,
     };
 
     return result;
@@ -10990,7 +11229,7 @@ bool llama_mlock_supported(void) {
     return llama_supports_mlock();
 }
 
-void llama_backend_init(bool numa) {
+void llama_backend_init(void) {
     ggml_time_init();
 
     // needed to initialize f16 tables
@@ -11000,15 +11239,17 @@ void llama_backend_init(bool numa) {
         ggml_free(ctx);
     }
 
-    if (numa) {
-        ggml_numa_init();
-    }
-
 #ifdef GGML_USE_MPI
     ggml_mpi_backend_init();
 #endif
 }
 
+void llama_numa_init(enum ggml_numa_strategy numa) {
+    if (numa != GGML_NUMA_STRATEGY_DISABLED) {
+        ggml_numa_init(numa);
+    }
+}
+
 void llama_backend_free(void) {
 #ifdef GGML_USE_MPI
     ggml_mpi_backend_free();
@@ -11085,6 +11326,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.mul_mat_q        = params.mul_mat_q;
     cparams.offload_kqv      = params.offload_kqv;
+    cparams.do_pooling       = params.do_pooling;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -11232,14 +11474,14 @@ struct llama_context * llama_new_context_with_model(
         // resized during inference, reserve maximum
         ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
-        if (params.embedding){
+        if (params.embedding) {
             ctx->embedding.resize(hparams.n_embd);
         }
 
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*7,
+                /* .mem_size   */ ggml_tensor_overhead()*8,
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
@@ -11249,15 +11491,19 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_embd    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
             ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
+            ctx->inp_KQ_pos  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
             ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
-            ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
+            ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
+            ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
 
             ggml_set_name(ctx->inp_tokens,  "inp_tokens");
             ggml_set_name(ctx->inp_embd,    "inp_embd");
             ggml_set_name(ctx->inp_pos,     "inp_pos");
             ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
+            ggml_set_name(ctx->inp_KQ_pos,  "inp_KQ_pos");
             ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
-            ggml_set_name(ctx->inp_sum,     "inp_sum");
+            ggml_set_name(ctx->inp_mean,    "inp_mean");
+            ggml_set_name(ctx->inp_cls,     "inp_cls");
 
             ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
 
@@ -12108,6 +12354,10 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+}
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     return model->vocab.id_to_token[token].text.c_str();
 }
@@ -12258,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
+
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string & tmpl,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (tmpl.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : chat) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (tmpl.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl.find("<|user|>") != std::string::npos) {
+        // zephyr template
+        for (auto message : chat) {
+            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+LLAMA_API int32_t llama_chat_apply_template(
+                const struct llama_model * model,
+                              const char * tmpl,
+         const struct llama_chat_message * chat,
+                                  size_t   n_msg,
+                                    bool   add_ass,
+                                    char * buf,
+                                 int32_t   length) {
+    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+    if (tmpl == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
+        if (res < 0) {
+            // worst case: there is no information about template, we will use chatml by default
+            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            curr_tmpl = std::string(model_template.data(), model_template.size());
+        }
+    }
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
index 367e8f1a105a5f8cc9f7d8d046f29a467e6262a5..77a84c18a69cfbc56841683ac0de3e9e7cc9d5f7 100644 (file)
@@ -100,6 +100,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q2_K_S        = 21, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q3_K_XS       = 22, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ3_XXS       = 23, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ1_S         = 24, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -112,6 +113,12 @@ extern "C" {
         LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN,
     };
 
+    enum llama_pooling_type {
+        LLAMA_POOLING_NONE = 0,
+        LLAMA_POOLING_MEAN = 1,
+        LLAMA_POOLING_CLS  = 2,
+    };
+
     enum llama_split_mode {
         LLAMA_SPLIT_NONE    = 0, // single GPU
         LLAMA_SPLIT_LAYER   = 1, // split layers and KV across GPUs
@@ -236,6 +243,7 @@ extern "C" {
         bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
     };
 
     // model quantization parameters
@@ -297,6 +305,12 @@ extern "C" {
         int32_t n_eval;
     };
 
+    // used in chat template
+    typedef struct llama_chat_message {
+        const char * role;
+        const char * content;
+    } llama_chat_message;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -305,7 +319,10 @@ extern "C" {
     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations
     // Call once at the start of the program
-    LLAMA_API void llama_backend_init(bool numa);
+    LLAMA_API void llama_backend_init(void);
+
+    //optional:
+    LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
 
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
@@ -628,6 +645,10 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
+    // Get the embeddings for the ith sequence
+    // llama_get_embeddings(ctx) + i*n_embd
+    LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
@@ -684,6 +705,25 @@ extern "C" {
                                   char * buf,
                                int32_t   length);
 
+    /// Apply chat template. Inspired by hf apply_chat_template() on python.
+    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
+    /// NOTE: This function only support some known jinja templates. It is not a jinja parser.
+    /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param chat Pointer to a list of multiple llama_chat_message
+    /// @param n_msg Number of llama_chat_message in this chat
+    /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
+    /// @param length The size of the allocated buffer
+    /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
+    LLAMA_API int32_t llama_chat_apply_template(
+              const struct llama_model * model,
+                            const char * tmpl,
+       const struct llama_chat_message * chat,
+                                size_t   n_msg,
+                                  bool   add_ass,
+                                  char * buf,
+                               int32_t   length);
+
     //
     // Grammar
     //
index 9f18a39a2cce18f4e5e29d6ab913aae722d69bab..60dd99e5d692c91e44cb053a608bdff8e0cda75d 100644 (file)
@@ -288,7 +288,7 @@ int main(int argc, char ** argv) {
 
     // llama init
 
-    llama_backend_init(true);
+    llama_backend_init();
 
     auto lmparams = llama_model_default_params();
     if (!params.use_gpu) {
index 844eff3dad1b3fde6223b18aedd6655e3238e76f..263260702e640a2062ba0d0bcfae9aa6d236536a 100644 (file)
@@ -264,26 +264,29 @@ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) {
         offset += 1;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x40)) {
+    if (!(utf8[offset + 0] & 0x40)) {
         throw std::invalid_argument("invalid character");
     }
-    else if (!(utf8[offset + 0] & 0x20)) {
-        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x20)) {
+        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
         offset += 2;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x10)) {
-        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x10)) {
+        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
         offset += 3;
         return result;
     }
-    else if (!(utf8[offset + 0] & 0x08)) {
-        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80))
+    if (!(utf8[offset + 0] & 0x08)) {
+        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
+        }
         auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
         offset += 4;
         return result;
@@ -331,21 +334,22 @@ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t
         offset += 1;
         return result;
     }
-    else {
-        if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00))
-            throw std::invalid_argument("invalid character");
-        auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
-        offset += 2;
-        return result;
+
+    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+        throw std::invalid_argument("invalid character");
     }
-    throw std::invalid_argument("invalid string");
+
+    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+    offset += 2;
+    return result;
 }
 
 static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) {
     std::vector<uint32_t> result;
     size_t offset = 0;
-    while (offset < utf16.size())
+    while (offset < utf16.size()) {
         result.push_back(codepoint_from_utf16(utf16, offset));
+    }
     return result;
 }
 
@@ -361,44 +365,52 @@ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> &
 static std::unordered_map<uint32_t, int> codepoint_type_map() {
     std::unordered_map<uint32_t, int> codepoint_types;
     for (auto p : digit_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_DIGIT;
+        }
     }
-    for(auto p : letter_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : letter_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_LETTER;
+        }
     }
-    for(auto p : whitespace_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : whitespace_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE;
+        }
     }
-    for(auto p : accent_mark_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : accent_mark_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
+        }
     }
-    for(auto p : punctuation_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : punctuation_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION;
+        }
     }
-    for (auto p : symbol_ranges) {
-        for (auto i = p.first; i <= p.second; ++i)
+    for  (auto p : symbol_ranges) {
+        for (auto i = p.first; i <= p.second; ++i) {
             codepoint_types[i] = CODEPOINT_TYPE_SYMBOL;
+        }
     }
-    for(auto p : control_ranges) {
-        for(auto i = p.first; i <= p.second; ++ i)
+    for (auto p : control_ranges) {
+        for (auto i = p.first; i <= p.second; ++ i) {
             codepoint_types[i] = CODEPOINT_TYPE_CONTROL;
+        }
     }
     return codepoint_types;
 }
 
 static int codepoint_type(uint32_t cp) {
     static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map();
-    return codepoint_types[cp];
+    return codepoint_types.find(cp) == codepoint_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : codepoint_types.at(cp);
 }
 
 static int codepoint_type(const std::string & utf8) {
-    if (utf8.length() == 0)
+    if (utf8.length() == 0) {
         return CODEPOINT_TYPE_UNIDENTIFIED;
+    }
     size_t offset = 0;
     return codepoint_type(codepoint_from_utf8(utf8, offset));
 }