self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+@Model.register("OlmoForCausalLM")
+@Model.register("OLMoForCausalLM")
+class OlmoModel(Model):
+ model_arch = gguf.MODEL_ARCH.OLMO
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_layer_norm_eps(1e-5)
+ if "clip_qkv" in self.hparams is not None:
+ self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"])
+
+ # Same as super class, but permuting q_proj, k_proj
+ # Copied from: LlamaModel
+ def write_tensors(self):
+ block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+ tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+ n_head = self.hparams.get("num_attention_heads")
+ n_kv_head = self.hparams.get("num_key_value_heads")
+ for name, data_torch in self.get_tensors():
+ old_dtype = data_torch.dtype
+
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+
+ data = data_torch.numpy()
+
+ if name.endswith("q_proj.weight"):
+ data = permute(data, n_head, n_head)
+ if name.endswith("k_proj.weight"):
+ data = permute(data, n_head, n_kv_head)
+
+ data = data.squeeze()
+
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+ if new_name is None:
+ print(f"Can not map tensor {name!r}")
+ sys.exit()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # if f32 desired, convert any float16 to float32
+ if self.ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # 1d tensors need to be converted to float32
+ if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2:
+ data = data.astype(np.float16)
+
+ print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+ self.gguf_writer.add_tensor(new_name, data)
+
+
###### CONVERSION LOGIC ######
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
+ LLM_ARCH_OLMO,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
+ { LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
+ {
+ LLM_ARCH_OLMO,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_OLMO:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
+
+ switch (hparams.n_layer) {
+ case 22: model.type = e_model::MODEL_1B; break;
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 80: model.type = e_model::MODEL_70B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
+ case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+ // if output is NULL, init from the input tok embed
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+ ml.n_created--; // artificial tensor
+ ml.size_data += ggml_nbytes(model.output);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
return gf;
}
+
+ // ref: https://allenai.org/olmo
+ // based on the original build_llama() function, changes:
+ // * non-parametric layer norm
+ // * clamp qkv
+ // * removed bias
+ // * removed MoE
+ struct ggml_cgraph * build_olmo() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ NULL, NULL,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (hparams.f_clamp_kqv > 0.0f) {
+ Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+ cb(Qcur, "Qcur", il);
+ }
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (hparams.f_clamp_kqv > 0.0f) {
+ Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+ cb(Kcur, "Kcur", il);
+ }
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (hparams.f_clamp_kqv > 0.0f) {
+ Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ model.layers[il].wo, nullptr,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ NULL, NULL,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+ if (layer_dir != nullptr) {
+ cur = ggml_add(ctx0, cur, layer_dir);
+ }
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ NULL, NULL,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
{
result = llm.build_dbrx();
} break;
+ case LLM_ARCH_OLMO:
+ {
+ result = llm.build_olmo();
+ } break;
default:
GGML_ASSERT(false);
}
case LLM_ARCH_MINICPM:
case LLM_ARCH_XVERSE:
case LLM_ARCH_COMMAND_R:
+ case LLM_ARCH_OLMO:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2