logger.warning(f"Failed to load model config from {dir_model}: {e}")
logger.warning("Trying to load config.json instead")
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
- return json.load(f)
+ config = json.load(f)
+ if "llm_config" in config:
+ # InternVL stores the language model config under "llm_config"; copy it to "text_config"
+ config["text_config"] = config["llm_config"]
+ return config
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.hf_arch == "Qwen2Model":
name = f"model.{name}" # map to Qwen2ForCausalLM tensors
+ if "language_model." in name:
+ name = name.replace("language_model.", "") # for InternVL
+ if name.startswith("mlp") or name.startswith("vision_model"):
+ # skip vision tensors; they are converted separately by InternVisionModel
+ return []
yield from super().modify_tensors(data_torch, name, bid)
return [] # skip other tensors
+@ModelBase.register("InternVisionModel")
+class InternVisionModel(VisionModel):
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
+ self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+ # hidden_act
+ if hparams["hidden_act"] == "silu":
+ self.gguf_writer.add_vision_use_silu(True)
+ elif hparams["hidden_act"] == "gelu":
+ self.gguf_writer.add_vision_use_gelu(True)
+ else:
+ raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+ # downsample_ratio
+ downsample_ratio = self.global_config.get("downsample_ratio")
+ assert downsample_ratio is not None
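+ # e.g. a downsample_ratio of 0.5 (typical for InternVL) corresponds to a projector scale factor of 2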
+ self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
+
+ def tensor_force_quant(self, name, new_name, bid, n_dims):
+ del bid, name, n_dims # unused
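+ # force the patch embedding to F16 and the position embedding to F32; everything else keeps the requested quantization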
+ if ".patch_embd." in new_name:
+ return gguf.GGMLQuantizationType.F16
+ if ".position_embd." in new_name:
+ return gguf.GGMLQuantizationType.F32
+ return False
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+ if name.startswith("vision_model") or name.startswith("mlp"):
+ # process visual tensors
+ # prefix with "vision_tower." so the name matches the tensor map
+ if name.startswith("vision_model"):
+ name = "vision_tower." + name
+ if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
+ name += ".weight"
+ # split QKV tensors if needed
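+ # e.g. a fused qkv weight of shape (3 * n_embd, n_embd) becomes three (n_embd, n_embd) projections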
+ if ".qkv." in name:
+ if data_torch.ndim == 2: # weight
+ c3, _ = data_torch.shape
+ else: # bias
+ c3 = data_torch.shape[0]
+ assert c3 % 3 == 0
+ c = c3 // 3
+ wq = data_torch[:c]
+ wk = data_torch[c: c * 2]
+ wv = data_torch[c * 2:]
+ return [
+ (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq),
+ (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk),
+ (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv),
+ ]
+ return [(self.map_tensor_name(name), data_torch)]
+ return [] # skip other tensors
+
+
@ModelBase.register("WavTokenizerDec")
class WavTokenizerDecModel(TextModel):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
head_dim = n_embd // num_heads
num_groups = num_heads // q_per_kv
+ name = name.replace("language_model.", "") # InternVL
+ if name.startswith("mlp") or name.startswith("vision_model"):
+ # skip visual tensors
+ return []
+
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
qkv = data_torch
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
+ name = name.replace("language_model.", "") # InternVL
+ if name.startswith("mlp") or name.startswith("vision_model"):
+ # skip visual tensors
+ return []
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
// layernorm 2
ggml_tensor * ln_2_w = nullptr;
ggml_tensor * ln_2_b = nullptr;
+
+ // learnable layer scale, multiplied after the attention and FFN outputs (no bias)
+ ggml_tensor * ls_1_w = nullptr;
+ ggml_tensor * ls_2_w = nullptr;
};
struct clip_vision_model {
// Qwen2VL and Qwen2.5VL use M-RoPE
ggml_cgraph * build_qwen2vl() {
+ GGML_ASSERT(model.patch_bias == nullptr);
+ GGML_ASSERT(model.class_embedding == nullptr);
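+ // this graph does not handle a patch bias or a CLS embedding, so make sure none were loaded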
+
const int batch_size = 1;
const bool use_window_attn = hparams.n_wa_pattern > 0;
const int n_wa_pattern = hparams.n_wa_pattern;
n_embd, n_patches_x * n_patches_y, batch_size);
}
- if (model.patch_bias) {
- inp = ggml_add(ctx0, inp, model.patch_bias);
- }
-
ggml_tensor * inpL = inp;
ggml_tensor * window_mask = nullptr;
ggml_tensor * window_idx = nullptr;
return gf;
}
+ ggml_cgraph * build_internvl() {
+ GGML_ASSERT(model.class_embedding != nullptr);
+ GGML_ASSERT(model.position_embeddings != nullptr);
+
+ const int n_pos = n_patches + 1;
+ ggml_tensor * inp = build_inp();
+
+ // add CLS token
+ inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
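+ // ggml_concat along dim 1 appends the CLS embedding after the n_patches patch tokens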
+
+ ggml_tensor * cur = build_vit(
+ inp, n_pos,
+ NORM_TYPE_NORMAL,
+ hparams.ffn_op,
+ model.position_embeddings,
+ nullptr);
+
+ // remove CLS token
+ cur = ggml_view_2d(ctx0, cur,
+ n_embd, n_patches,
+ ggml_row_size(cur->type, n_embd), 0);
+
+ // pixel shuffle
+ {
+ const int scale_factor = model.hparams.proj_scale_factor;
+ const int bsz = 1; // batch size, always 1 for now since we don't support batching
+ const int height = n_patches_y;
+ const int width = n_patches_x;
+ GGML_ASSERT(scale_factor > 0);
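+ // worked example (hypothetical numbers): with a 32x32 patch grid, n_embd = 1024 and scale_factor = 2,
+ // the 1024 patch tokens are regrouped into 16x16 = 256 tokens of dimension 1024 * 2 * 2 = 4096
+ // before the projector maps them down to the LLM embedding size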
+ cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+ n_embd * scale_factor * scale_factor,
+ height / scale_factor,
+ width / scale_factor,
+ bsz);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ // flatten to 2D
+ cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+ n_embd * scale_factor * scale_factor,
+ cur->ne[1] * cur->ne[2]);
+ }
+
+ // projector (always using GELU activation)
+ {
+ // projector LayerNorm uses pytorch's default eps = 1e-5
+ // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
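+ // mm_0 = LayerNorm, mm_1 / mm_3 = Linear; the indices follow the HF nn.Sequential layout, where index 2 is the GELU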
+ cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+ cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+ cur = ggml_add(ctx0, cur, model.mm_1_b);
+ cur = ggml_gelu(ctx0, cur);
+ cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+ cur = ggml_add(ctx0, cur, model.mm_3_b);
+ }
+
+ // build the graph
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
ggml_tensor * inp = build_inp();
- if (model.patch_bias) {
- inp = ggml_add(ctx0, inp, model.patch_bias);
- }
-
// concat class_embeddings and patch_embeddings
if (model.class_embedding) {
inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
ggml_tensor * learned_pos_embd,
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
) {
- if (model.patch_bias) {
- inp = ggml_add(ctx0, inp, model.patch_bias);
- cb(inp, "patch_bias", -1);
- }
-
if (learned_pos_embd) {
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
cb(cur, "attn_out", il);
}
+ if (layer.ls_1_w) {
+ cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+ cb(cur, "attn_out_scaled", il);
+ }
+
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "ffn_out", il);
+ if (layer.ls_2_w) {
+ cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+ cb(cur, "ffn_out_scaled", il);
+ }
+
// residual 2
cur = ggml_add(ctx0, inpL, cur);
cb(cur, "layer_out", il);
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+ if (model.patch_bias) {
+ inp = ggml_add(ctx0, inp, model.patch_bias);
+ cb(inp, "patch_bias", -1);
+ }
return inp;
}
{
res = graph.build_minicpmv();
} break;
+ case PROJECTOR_TYPE_INTERNVL:
+ {
+ res = graph.build_internvl();
+ } break;
default:
{
res = graph.build_llava();
}
} break;
case PROJECTOR_TYPE_IDEFICS3:
+ case PROJECTOR_TYPE_INTERNVL:
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
+ layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
+ layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
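+ // layer scale tensors are optional; they are present in InternViT-style encoders but absent from most other ViTs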
+
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
- // new naming
+ // ffn
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
+ case PROJECTOR_TYPE_INTERNVL:
+ {
+ vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+ vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+ vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+ vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+ vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+ vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+ } break;
default:
GGML_ASSERT(false && "unknown projector type");
}
}
else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
- || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+ || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+ || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+ ) {
clip_image_u8 resized_image;
int sz = params.image_size;
image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
- if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+ if (ctx->proj_type == PROJECTOR_TYPE_LDP
+ || ctx->proj_type == PROJECTOR_TYPE_LDPV2
+ || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
n_patches /= 4;
- n_patches += 2; // for BOI and EOI token embeddings
+ if (ctx->vision_model.mm_glm_tok_boi) {
+ n_patches += 2; // for BOI and EOI token embeddings
+ }
} else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
if (ctx->minicpmv_version == 2) {
n_patches = 96;
int n_per_side = params.image_size / params.patch_size;
int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
- } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+ } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
+ // both W and H are divided by proj_scale_factor
n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_merge = params.spatial_merge_size;
} break;
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
+ case PROJECTOR_TYPE_INTERNVL:
{
// do nothing
} break;
// the last node is the embedding tensor
ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+ // sanity check (only a batch size of 1 is supported for now)
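+ // the projector output is expected to be [n_embd, n_tokens] for a single image, so ne[1] is the token count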
+ const int n_tokens_out = embeddings->ne[1];
+ const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+ if (n_tokens_out != expected_n_tokens_out) {
+ LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+ GGML_ABORT("Invalid number of output tokens");
+ }
+
// copy the embeddings to the location passed by the user
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
+ case PROJECTOR_TYPE_INTERNVL:
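+ // ne[1] of the last projector linear equals the LLM embedding size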
+ return ctx->vision_model.mm_3_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}