llama : support for `falcon-mamba` architecture (#9074)
author    Younes Belkada <redacted>
          Wed, 21 Aug 2024 08:06:36 +0000 (12:06 +0400)
committer GitHub <redacted>
          Wed, 21 Aug 2024 08:06:36 +0000 (11:06 +0300)
* feat: initial support for FalconMamba in llama.cpp

* fix: lint

* refactor: better refactor

* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* fix: address comments

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* fix: add more cleanup and harmonization

* fix: lint

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* fix: change name

* Apply suggestions from code review

Co-authored-by: compilade <redacted>
* add in operator

* fix: add `dt_b_c_rms` in `llm_load_print_meta`

* fix: correct printf format for bool

* fix: correct print format

* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* llama : quantize more Mamba tensors

* llama : use f16 as the fallback of fallback quant types

---------

Co-authored-by: compilade <redacted>
README.md
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
src/llama.cpp

index 04d315db7144ed2f598e04cde44191ce82b2281f..bb2b93a35021fc00db18b77b19db942b55430bb7 100644 (file)
--- a/README.md
+++ b/README.md
@@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
 - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
+- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
 
index 6a1a3a937febd611506af9b6e02b0abfa1272c0e..108c822cff5d22fb2c0430b1b1ae279db680902b 100755 (executable)
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -295,6 +295,7 @@ class Model:
                             gguf.MODEL_TENSOR.FFN_GATE_INP,
                             gguf.MODEL_TENSOR.POS_EMBD,
                             gguf.MODEL_TENSOR.TOKEN_TYPES,
+                            gguf.MODEL_TENSOR.SSM_CONV1D,
                         )
                     )
                     or not name.endswith(".weight")
@@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
-@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
@@ -2742,7 +2743,10 @@ class MambaModel(Model):
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
         dt_rank      = self.find_hparam(["time_step_rank",     "dt_rank"],      optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-
+        use_dt_b_c_norm = False
+        # For falconmamba we do apply RMS norm on B / DT and C layers
+        if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
+            use_dt_b_c_norm = True
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
@@ -2750,12 +2754,13 @@ class MambaModel(Model):
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)
 
     _tok_embd = None
@@ -2782,23 +2787,6 @@ class MambaModel(Model):
 
         return [(new_name, data_torch)]
 
-    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        if bid is not None and new_name in (
-            self.format_tensor_name(
-                n, bid, ".weight" if name.endswith(".weight") else ""
-            )
-            for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        ):
-            return gguf.GGMLQuantizationType.F32
-
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
@@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        assert(hparams["activation_function"] == "silu")
+        assert (hparams["activation_function"] == "silu")
 
         max_position_embeddings = hparams["max_position_embeddings"]
         embed_dim = hparams["hidden_size"]
@@ -3855,8 +3843,8 @@ class ExaoneModel(Model):
 
         super().prepare_tensors()
 
-###### CONVERSION LOGIC ######
 
+###### CONVERSION LOGIC ######
 
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):
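
The converter change above reuses the existing MambaModel path and only flips a new flag when the Hugging Face config reports model_type == "falcon_mamba". A minimal standalone sketch of that detection follows; the helper name and the literal hparams dict are illustrative, not part of the converter:

    # Sketch of the FalconMamba check added in MambaModel.set_gguf_parameters().
    # `hparams` stands in for the model's Hugging Face config dict.
    def wants_dt_b_c_rms(hparams: dict) -> bool:
        # FalconMamba applies RMS norm to the dt, B and C projections;
        # classic Mamba checkpoints do not, so the flag defaults to False.
        return hparams.get("model_type") in ("falcon_mamba",)

    hparams = {"model_type": "falcon_mamba", "n_layer": 64}
    use_dt_b_c_norm = wants_dt_b_c_rms(hparams)  # True only for FalconMamba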
index 5541972ce52b0dff64a4597b6fadc82809ee4408..b55effa9907b100106cc1fe5a88e2abbb7dd505d 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -130,6 +130,7 @@ class Keys:
         INNER_SIZE     = "{arch}.ssm.inner_size"
         STATE_SIZE     = "{arch}.ssm.state_size"
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+        DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"
 
     class Tokenizer:
         MODEL                = "tokenizer.ggml.model"
@@ -1372,6 +1373,7 @@ KEY_SSM_CONV_KERNEL    = Keys.SSM.CONV_KERNEL
 KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
 KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
 KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_DT_B_C_RMS     = Keys.SSM.DT_B_C_RMS
 
 # tokenization
 KEY_TOKENIZER_MODEL      = Keys.Tokenizer.MODEL
index 76385a82872c941a6b602dda151dd38cdb15e766..af3b98c679b0b66cd36d0a1ab5dafb8262936a0f 100644 (file)
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -730,6 +730,9 @@ class GGUFWriter:
     def add_ssm_time_step_rank(self, value: int) -> None:
         self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
 
+    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
 
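
The new writer method stores the flag as a boolean key-value pair named "{arch}.ssm.dt_b_c_rms" (so "mamba.ssm.dt_b_c_rms" here, since FalconMamba reuses MODEL_ARCH.MAMBA). A rough, metadata-only usage sketch with gguf-py; the output path and block count are placeholders and no tensors are written:

    import gguf

    writer = gguf.GGUFWriter("falcon-mamba-meta.gguf", arch="mamba")
    writer.add_block_count(64)            # placeholder layer count
    writer.add_ssm_dt_b_c_rms(True)       # written as the bool KV "mamba.ssm.dt_b_c_rms"
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()        # no tensors were added in this sketch
    writer.close()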
index 5ab65ea97defadb106f3f4a7d55eabde9f14d033..fe3c0db6f2931d9e1c3082ab6aa098f97ea36758 100644 (file)
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -328,6 +328,7 @@ enum llm_kv {
     LLM_KV_SSM_CONV_KERNEL,
     LLM_KV_SSM_STATE_SIZE,
     LLM_KV_SSM_TIME_STEP_RANK,
+    LLM_KV_SSM_DT_B_C_RMS,
 
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_INNER_SIZE,                "%s.ssm.inner_size"     },
     { LLM_KV_SSM_STATE_SIZE,                "%s.ssm.state_size"     },
     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
+    { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" },
 
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
@@ -2237,6 +2239,7 @@ struct llama_hparams {
     uint32_t ssm_d_inner = 0;
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
+    bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2286,6 +2289,7 @@ struct llama_hparams {
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
@@ -5052,6 +5056,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
                 ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
 
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -5907,6 +5912,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
@@ -12161,6 +12167,10 @@ struct llm_build_context {
         GGML_ASSERT(2 * d_model == d_inner);
         const int64_t d_state = hparams.ssm_d_state;
         const int64_t dt_rank = hparams.ssm_dt_rank;
+        // Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
+        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        // Use the same RMS norm as the final layer norm
+        const float norm_rms_eps = hparams.f_norm_rms_eps;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -12241,6 +12251,13 @@ struct llm_build_context {
                 struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
                 struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
+                // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+                if (ssm_dt_b_c_rms) {
+                    dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+                    B = ggml_rms_norm(ctx0, B, norm_rms_eps);
+                    C = ggml_rms_norm(ctx0, C, norm_rms_eps);
+                }
+
                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
                 dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
@@ -16105,6 +16122,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
             default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
         }
+        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+            new_type = GGML_TYPE_F16;
+        }
         LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
         ++qs.n_fallback;
     }
@@ -16433,8 +16453,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-        quantize &= name.find("ssm_x.weight")      == std::string::npos;
-        quantize &= name.find("ssm_dt.weight")     == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
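
On the inference side, the only graph change is an extra, weight-less RMS normalization of dt, B and C before the SSM scan, reusing the model's final-norm epsilon. A small numpy sketch of what ggml_rms_norm computes for those views (toy sizes and epsilon for illustration; normalization is over the channel axis, with no learned scale):

    import numpy as np

    def rms_norm(x: np.ndarray, eps: float) -> np.ndarray:
        # ggml_rms_norm: x / sqrt(mean(x^2) + eps), no learned weight at this step.
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

    n_tokens, dt_rank, d_state = 4, 16, 16   # toy sizes, not FalconMamba's real dims
    norm_rms_eps = 1e-6                      # illustrative; the real value comes from the GGUF metadata
    dt = rms_norm(np.random.randn(n_tokens, dt_rank), norm_rms_eps)
    B  = rms_norm(np.random.randn(n_tokens, d_state), norm_rms_eps)
    C  = rms_norm(np.random.randn(n_tokens, d_state), norm_rms_eps)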