if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
# ref: https://huggingface.co/core42/jais-13b
res = "jais"
+ if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a":
+ # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat
+ res = "jais-2"
if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f":
# ref: https://huggingface.co/WisdomShell/CodeShell-7B
res = "codeshell"
yield from super().modify_tensors(data_torch, name, bid)
+@ModelBase.register("Jais2ForCausalLM")
+class Jais2Model(TextModel):
+ model_arch = gguf.MODEL_ARCH.JAIS2
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
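+ # RoPE is applied per attention head, so the rotation dimension equals the head size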
+ head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])
+ self.gguf_writer.add_rope_dimension_count(head_dim)
+
+
@ModelBase.register("JAISLMHeadModel")
class JaisModel(TextModel):
model_arch = gguf.MODEL_ARCH.JAIS
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
+ {"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
+ JAIS2 = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
NEMOTRON_H_MOE = auto()
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
+ MODEL_ARCH.JAIS2: "jais2",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.JAIS2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
models/hunyuan-moe.cpp
models/internlm2.cpp
models/jais.cpp
+ models/jais2.cpp
models/jamba.cpp
models/kimi-linear.cpp
models/lfm2.cpp
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
+ { LLM_ARCH_JAIS2, "jais2" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
};
+ case LLM_ARCH_JAIS2:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_DOWN,
+ };
case LLM_ARCH_NEMOTRON_H:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
+ LLM_ARCH_JAIS2,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_NEMOTRON_H_MOE,
if (down) {
cur = build_lora_mm(down, cur);
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
- // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
ggml_tensor * cur;
- if (cparams.flash_attn && kq_b == nullptr) {
+ const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
+ if (use_flash_attn) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
if (v_trans) {
if (wo) {
cur = build_lora_mm(wo, cur);
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
- // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_JAIS2:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+
+ switch (hparams.n_layer) {
+ case 32: type = LLM_TYPE_8B; break;
+ case 68: type = LLM_TYPE_70B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_NEMOTRON:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
}
} break;
+ case LLM_ARCH_JAIS2:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
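+ // fall back to tied token embeddings when the checkpoint has no separate output tensor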
+ if (!output) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ // attention biases - each matches the output dimension of its projection
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, 0);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, 0);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+ // Jais-2 uses simple MLP (no gate) with biases
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_CHATGLM:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
{
llm = std::make_unique<llm_build_jais>(*this, params);
} break;
+ case LLM_ARCH_JAIS2:
+ {
+ llm = std::make_unique<llm_build_jais2>(*this, params);
+ } break;
case LLM_ARCH_NEMOTRON:
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
+ case LLM_ARCH_JAIS2:
case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_HUNYUAN_DENSE:
case LLM_ARCH_LFM2:
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_JAIS2:
+ regex_exprs = {
+ // original regex from tokenizer.json
+ //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}",
+
+ // adapted: same as llama3 but with cascading whitespace pattern
+ "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}",
+ };
+ break;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
regex_exprs = {
tokenizer_pre == "jina-v2-de" ||
tokenizer_pre == "a.x-4.0" ||
tokenizer_pre == "mellum" ||
- tokenizer_pre == "modern-bert" ) {
+ tokenizer_pre == "modern-bert") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+ } else if (tokenizer_pre == "jais-2") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
} else if (
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-code" ||
LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46,
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
+ LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
};
struct LLM_KV;
--- /dev/null
+#include "models.h"
+
+// JAIS-2 model graph builder
+// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
+llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // KV input for attention
+ auto * inp_attn = build_attn_inp_kv();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ // Pre-attention LayerNorm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, il);
+ cb(cur, "attn_norm", il);
+
+ // Self-attention with separate Q, K, V projections
+ {
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur_bias", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur_bias", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur_bias", il);
+
+ // Reshape for attention
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ // Apply RoPE
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur_rope", il);
+ cb(Kcur, "Kcur_rope", il);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ // Residual connection
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // Pre-FFN LayerNorm
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm,
+ model.layers[il].ffn_norm_b,
+ LLM_NORM, il);
+ cb(cur, "ffn_norm", il);
+
+ // FFN with relu2 activation (ReLU squared) - no gate projection
+ // up -> relu2 -> down
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+ NULL, NULL, NULL, // no gate
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+ NULL,
+ LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+ cb(cur, "ffn_out", il);
+
+ // Residual connection
+ inpL = ggml_add(ctx0, cur, ffn_inp);
+ inpL = build_cvec(inpL, il);
+ cb(inpL, "l_out", il);
+ }
+
+ // Final LayerNorm
+ cur = build_norm(inpL,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, -1);
+ cb(cur, "result_norm", -1);
+
+ res->t_embd = cur;
+
+ // Output projection
+ cur = build_lora_mm(model.output, cur);
+ cb(cur, "result_output", -1);
+
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}
llm_build_jais(const llama_model & model, const llm_graph_params & params);
};
+struct llm_build_jais2 : public llm_graph_context {
+ llm_build_jais2(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_jamba : public llm_build_mamba_base {
llm_build_jamba(const llama_model & model, const llm_graph_params & params);
};