struct clip_hparams {
int32_t image_size;
int32_t patch_size;
- int32_t hidden_size;
- int32_t n_intermediate;
+ int32_t n_embd;
+ int32_t n_ff;
int32_t projection_dim;
int32_t n_head;
int32_t n_layer;
struct ggml_tensor * ln_1_w = nullptr;
struct ggml_tensor * ln_1_b = nullptr;
- // ff
- struct ggml_tensor * ff_i_w = nullptr; // legacy naming
- struct ggml_tensor * ff_i_b = nullptr; // legacy naming
- struct ggml_tensor * ff_o_w = nullptr; // legacy naming
- struct ggml_tensor * ff_o_b = nullptr; // legacy naming
-
struct ggml_tensor * ff_up_w = nullptr;
struct ggml_tensor * ff_up_b = nullptr;
struct ggml_tensor * ff_gate_w = nullptr;
+ struct ggml_tensor * ff_gate_b = nullptr;
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
- struct ggml_tensor * ff_g_w = NULL;
- struct ggml_tensor * ff_g_b = NULL;
-
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
struct ggml_tensor * ln_2_b = nullptr;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
- const int hidden_size = hparams.hidden_size;
+ const int n_embd = hparams.n_embd;
const int n_head = hparams.n_head;
- const int d_head = hidden_size / n_head;
+ const int d_head = n_embd / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
- inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+ inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+ cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
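+ // shape flow: KQV is [d_head, num_patches, n_head]; the permute gives
+ // [d_head, n_head, num_patches], and cont_2d merges the heads back into
+ // [n_embd, num_patches] (n_embd == d_head * n_head)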
}
// attention output
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
- cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
- cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
// siglip uses gelu
cur = ggml_gelu(ctx0, cur);
- cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
- cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
const int kernel_size = patches_per_image / tokens_per_side;
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
- embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+ embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size);
// doing a pool2d to reduce the number of output tokens to 256
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
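+ // worked example (e.g. Gemma 3: 896px image / 14px patches = 64 patches per
+ // side, tokens_per_side = 16): kernel_size = 64 / 16 = 4, so the pooled grid
+ // is 16x16 = 256 tokens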
- embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
// apply norm before projection
const int n_patches_x = image_size_width / patch_size;
const int n_patches_y = image_size_height / patch_size;
const int num_patches = n_patches_x * n_patches_y;
- const int hidden_size = hparams.hidden_size;
+ const int n_embd = hparams.n_embd;
const int n_head = hparams.n_head;
- const int d_head = hidden_size / n_head;
+ const int d_head = n_embd / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
const int n_merge = hparams.spatial_merge_size;
ggml_set_input(pos_w);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
- inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+ inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
struct ggml_tensor * embeddings = inp;
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+ cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
}
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
// reshape image tokens to 2D grid
- cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
- cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+ cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+ cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
cur = ggml_cont(ctx0, cur);
// torch.nn.functional.unfold is just an im2col under the hood
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
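+ // each im2col row gathers one n_merge x n_merge window of patch embeddings,
+ // i.e. n_embd * n_merge * n_merge values, which mm_patch_merger_w below
+ // projects back down to n_embd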
- // project to hidden_size
+ // project to n_embd
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
embeddings = cur;
// arrangement of the [IMG_BREAK] token
{
// not efficient, but works
- // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
+ // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
- // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
+ // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
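+ // e.g. a 2x2 token grid [a, b; c, d] becomes: a b [IMG_BREAK] c d [IMG_BREAK]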
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position
- const int hidden_size = hparams.hidden_size;
+ const int n_embd = hparams.n_embd;
const int n_head = hparams.n_head;
- const int d_head = hidden_size / n_head;
+ const int d_head = n_embd / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
inp = ggml_reshape_4d(
ctx0, inp,
- hidden_size * 2, patches_w / 2, patches_h, batch_size);
+ n_embd * 2, patches_w / 2, patches_h, batch_size);
inp = ggml_reshape_4d(
ctx0, inp,
- hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+ n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
inp = ggml_reshape_3d(
ctx0, inp,
- hidden_size, patches_w * patches_h, batch_size);
+ n_embd, patches_w * patches_h, batch_size);
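+ // net effect (assuming Qwen2-VL's 2x2 patch-merge order): the four patches of
+ // each 2x2 block become consecutive tokens, so the merger below can fuse them
+ // with a single reshape to [n_embd * 4, num_positions / 4, ...]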
if (model.patch_bias) {
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
ggml_set_name(window_mask, "window_mask");
ggml_set_input(window_mask);
- // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+ // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
- embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+ embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4);
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
- embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size);
}
// loop over layers
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+ cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
}
// attention output
// mlp
// ffn_up
- auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
- cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+ auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+ cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b);
- auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
- cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+ auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
+ cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b);
// TODO : only 2 of these 3 are actually used, should we remove one of them?
if (ctx->use_gelu) {
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
} else if (ctx->use_silu) {
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
} else {
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
}
cur = ggml_mul(ctx0, cur_gate, cur_up);
// ffn_down
- cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
- cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
}
- embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
ggml_set_name(window_idx, "window_idx");
ggml_set_input(window_idx);
- // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+ // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
- const int hidden_size = hparams.hidden_size;
+ const int n_embd = hparams.n_embd;
const int n_head = hparams.n_head;
- const int d_head = hidden_size / n_head;
+ const int d_head = n_embd / n_head;
const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
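+ // M-RoPE: the rope dims are split evenly across the 4 per-position coordinates
+ // (the same reason num_position_ids is num_positions * 4 above)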
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
inp = ggml_reshape_4d(
ctx0, inp,
- hidden_size * 2, patches_w / 2, patches_h, batch_size);
+ n_embd * 2, patches_w / 2, patches_h, batch_size);
inp = ggml_reshape_4d(
ctx0, inp,
- hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+ n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
inp = ggml_reshape_3d(
ctx0, inp,
- hidden_size, patches_w * patches_h, batch_size);
+ n_embd, patches_w * patches_h, batch_size);
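+ // same 2x2 patch-merge reordering as in the window-attention path above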
}
else {
- inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+ inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size);
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
}
// concat class_embeddings and patch_embeddings
if (model.class_embedding) {
- embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+ embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size);
embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
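+ // position 0 now holds the class token (acc offset 0); the patch embeddings
+ // fill the remaining num_patches positions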
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+ cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
}
// attention output
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
- cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
- cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
if (ctx->use_gelu) {
cur = ggml_gelu_inplace(ctx0, cur);
} else if (ctx->use_silu) {
cur = ggml_silu_inplace(ctx0, cur);
} else {
cur = ggml_gelu_quick_inplace(ctx0, cur);
}
- cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
- cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
}
{ // attention
- int hidden_size = clip_n_mmproj_embd(ctx);
+ int n_embd = clip_n_mmproj_embd(ctx);
const int d_head = 128;
- int n_head = hidden_size/d_head;
+ int n_head = n_embd/d_head;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
num_query = 96;
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+ KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
}
}
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
- embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
- get_u32(KEY_N_EMBD, hparams.hidden_size);
+ get_u32(KEY_N_EMBD, hparams.n_embd);
get_u32(KEY_N_HEAD, hparams.n_head);
- get_u32(KEY_N_FF, hparams.n_intermediate);
+ get_u32(KEY_N_FF, hparams.n_ff);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
}
void load_tensors() {
+ auto & hparams = ctx_clip.vision_model.hparams;
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
// layers
- vision_model.layers.resize(vision_model.hparams.n_layer);
- for (int il = 0; il < vision_model.hparams.n_layer; ++il) {
+ vision_model.layers.resize(hparams.n_layer);
+ for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
- // legacy naming (the in and out is reversed! don't ask me why)
- layer.ff_i_w = layer.ff_down_w;
- layer.ff_o_w = layer.ff_up_w;
- layer.ff_g_w = layer.ff_gate_w;
- layer.ff_i_b = layer.ff_down_b;
- layer.ff_o_b = layer.ff_up_b;
- layer.ff_g_b = layer.ff_gate_b;
+ // some models were exported with the legacy (incorrect) naming, which is quite messy, so let's fix it here
+ // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
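+ // how the check works: ggml_mul_mat contracts over ne[0], so a correctly
+ // named ff_down_w has ne[0] == n_ff; ne[0] == n_embd means up/down were swapped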
+ if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+ // swap up and down weights
+ ggml_tensor * tmp = layer.ff_up_w;
+ layer.ff_up_w = layer.ff_down_w;
+ layer.ff_down_w = tmp;
+ // swap up and down biases
+ tmp = layer.ff_up_b;
+ layer.ff_up_b = layer.ff_down_b;
+ layer.ff_down_b = tmp;
+ }
}
switch (ctx_clip.proj_type) {
}
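+ // note: the public API keeps the "hidden_size" name; internally it is hparams.n_embd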
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
- return ctx->vision_model.hparams.hidden_size;
+ return ctx->vision_model.hparams.n_embd;
}
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {