llama-infill \
llama-llava-cli \
llama-minicpmv-cli\
+ llama-qwen2vl-cli\
llama-lookahead \
llama-lookup \
llama-lookup-create \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
+llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \
+ examples/llava/llava.cpp \
+ examples/llava/llava.h \
+ examples/llava/clip.cpp \
+ examples/llava/clip.h \
+ $(OBJ_ALL)
+ $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
+
ifeq ($(UNAME_S),Darwin)
swift: examples/batched.swift
(cd examples/batched.swift; make build)
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
+- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
</details>
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+@Model.register("Qwen2VLForConditionalGeneration")
+class Qwen2VLModel(Model):
+ model_arch = gguf.MODEL_ARCH.QWEN2VL
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ mrope_section = self.hparams["rope_scaling"]["mrope_section"]
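+ # pad the (temporal, height, width) split to 4 entries, e.g. an assumed [16, 24, 24] becomes [16, 24, 24, 0]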
+ mrope_section += [0] * max(0, 4 - len(mrope_section))
+ self.gguf_writer.add_rope_dimension_sections(mrope_section)
+
+ def set_vocab(self):
+ try:
+ self._set_vocab_sentencepiece()
+ except FileNotFoundError:
+ self._set_vocab_gpt2()
+
+ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+ for name, data in super().get_tensors():
+ if name.startswith("visual."):
+ continue
+ yield name, data
+
+
@Model.register("Qwen2MoeForCausalLM")
class Qwen2MoeModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN2MOE
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-qwen2vl-cli)
+add_executable(${TARGET} qwen2vl-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
+#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu"
+#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define TN_TOKEN_EMBD "%s.token_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
-#define TN_PATCH_EMBD "v.patch_embd.weight"
+#define TN_PATCH_EMBD "v.patch_embd.weight" // the first conv kernel keeps its original name (no ".0" suffix) for backward compat
+#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
#define TN_PATCH_BIAS "v.patch_embd.bias"
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_RESAMPLER,
+ PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_UNKNOWN,
};
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
+ { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
};
// embeddings
struct ggml_tensor * class_embedding;
- struct ggml_tensor * patch_embeddings;
+ struct ggml_tensor * patch_embeddings_0;
+ struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we split the Conv3D along the temporal dimension (Qwen2VL)
struct ggml_tensor * patch_bias;
struct ggml_tensor * position_embeddings;
bool has_vision_encoder = false;
bool has_llava_projector = false;
bool has_minicpmv_projector = false;
+ bool has_qwen2vl_merger = false;
int minicpmv_version = 2;
struct clip_vision_model vision_model;
float image_mean[3];
float image_std[3];
bool use_gelu = false;
+ bool use_silu = false;
int32_t ftype = 1;
bool has_class_embedding = true;
image_size_height = imgs->data->ny;
}
}
+ else if (ctx->has_qwen2vl_merger) {
+ // use the image's native resolution when it is available (i.e. at inference time)
+ if (is_inf) {
+ image_size_width = imgs->data->nx;
+ image_size_height = imgs->data->ny;
+ }
+ }
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+ const int patches_w = image_size_width / patch_size;
+ const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+ const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
int n_layer = hparams.n_layer;
const float eps = hparams.eps;
+ int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
const int batch_size = imgs->size;
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
- struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+ struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
- inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
- inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+ if (ctx->has_qwen2vl_merger) {
+ GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+ GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+
+ auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+ inp = ggml_add(ctx0, inp, inp_1);
+ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
+ inp = ggml_reshape_4d(
+ ctx0, inp,
+ hidden_size * 2, patches_w / 2, patches_h, batch_size);
+ inp = ggml_reshape_4d(
+ ctx0, inp,
+ hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+ inp = ggml_reshape_3d(
+ ctx0, inp,
+ hidden_size, patches_w * patches_h, batch_size);
+ }
+ else {
+ inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+ }
if (ctx->has_patch_bias) {
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
}
}
- struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+ struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
- embeddings =
- ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+ if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
+ embeddings =
+ ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+ }
if (ctx->has_minicpmv_projector) {
int pos_w = image_size_width/patch_size;
}
// loop over layers
- if (ctx->has_minicpmv_projector) {
+ if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
+ // TODO: figure out why we do it this way
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
- Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+ if (ctx->has_qwen2vl_merger) {
+ Q = ggml_rope_multi(
+ ctx0, Q, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+ }
+ Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+ if (ctx->has_qwen2vl_merger) {
+ K = ggml_rope_multi(
+ ctx0, K, positions, nullptr,
+ d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+ }
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
if (ctx->use_gelu) {
cur = ggml_gelu_inplace(ctx0, cur);
+ } else if (ctx->use_silu) {
+ cur = ggml_silu_inplace(ctx0, cur);
} else {
cur = ggml_gelu_quick_inplace(ctx0, cur);
}
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
GGML_ASSERT(false);
}
}
+ else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+ embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+
+ embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+ // GELU activation
+ embeddings = ggml_gelu(ctx0, embeddings);
+
+ // Second linear layer
+ embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+ }
// build the graph
ggml_build_forward_expand(gf, embeddings);
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
}
+ idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
+ if (idx != -1) {
+ new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
+ }
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
GGML_ASSERT(new_clip->has_vision_encoder);
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+ try {
+ idx = get_key_idx(ctx, KEY_USE_SILU);
+ new_clip->use_silu = gguf_get_val_bool(ctx, idx);
+ } catch (std::runtime_error & /*e*/) {
+ new_clip->use_silu = false;
+ }
+
if (verbosity >= 1) {
LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
}
try {
- vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+ vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
} catch(const std::exception& /*e*/) {
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
}
+ try {
+ vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
+ } catch(const std::exception& /*e*/) {
+ new_clip->has_qwen2vl_merger = false;
+ }
// LLaVA projection
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
}
+ else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
+ vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
+ vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
+ vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
+ vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+ }
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
clip_image_f32_batch batch;
batch.size = 1;
+ batch.data = nullptr;
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
ggml_gallocr_reserve(new_clip->compute_alloc, gf);
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
ctx_clip->load_image_size = load_image_size;
}
+struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
+ return ctx_clip->load_image_size;
+}
+
struct clip_image_size * clip_image_size_init() {
struct clip_image_size * load_image_size = new struct clip_image_size();
load_image_size->width = 448;
}
return true;
}
+ else if (ctx->has_qwen2vl_merger) {
+ clip_image_u8 * resized = clip_image_u8_init();
+ auto patch_size = clip_patch_size(ctx) * 2;
+ int nx = ceil((float)img->nx / patch_size) * patch_size;
+ int ny = ceil((float)img->ny / patch_size) * patch_size;
+ bicubic_resize(*img, *resized, nx, ny);
+
+ res_imgs->data = new clip_image_f32[1];
+ normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
+ res_imgs->size = 1;
+
+ clip_image_u8_free(resized);
+ return true;
+ }
bool pad_to_square = true;
if (!ctx->has_vision_encoder) {
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
+ clip_image_f32 img;
+ img.nx = img_w;
+ img.ny = img_h;
+ return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
int32_t clip_image_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_size;
}
}
int clip_n_patches(const struct clip_ctx * ctx) {
+ clip_image_f32 img;
+ img.nx = ctx->vision_model.hparams.image_size;
+ img.ny = ctx->vision_model.hparams.image_size;
+ return clip_n_patches_by_img(ctx, &img);
+}
+
+int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->vision_model.hparams;
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
+ } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+ int patch_size = params.patch_size * 2;
+ int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+ int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+ n_patches = x_patch * y_patch;
}
return n_patches;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
- if (ctx->has_minicpmv_projector) {
+ if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
image_size_width = imgs->data[0].nx;
image_size_height = imgs->data[0].ny;
}
for (size_t i = 0; i < imgs->size; i++) {
const int nx = imgs->data[i].nx;
const int ny = imgs->data[i].ny;
- if (!ctx->has_minicpmv_projector) {
+ if (!(ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger)) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
- for(int i=0;i<pos_w * pos_h;++i){
- for(int j=0;j<embed_dim;++j){
- pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j];
+ for(int i=0;i < pos_w * pos_h; ++i){
+ for(int j=0; j < embed_dim; ++j){
+ pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
}
}
}
}
- {
+ if (ctx->has_qwen2vl_merger) {
+ struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+ const int pw = image_size_width / patch_size;
+ const int ph = image_size_height / patch_size;
+ int* positions_data = (int*)malloc(ggml_nbytes(positions));
+
+ int ptr = 0;
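+ // 4 position ids per patch (one per rope section), stored as 4 consecutive blocks of
+ // num_patches entries; patches are visited in 2x2 groups to match the merger, and the
+ // blocks are filled as (row, col, row, col)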
+ for (int y = 0; y < ph; y+=2)
+ {
+ for (int x = 0; x < pw; x+=2)
+ {
+ for (int dy = 0; dy < 2; dy++) {
+ for (int dx = 0; dx < 2; dx++) {
+ positions_data[ptr] = y + dy;
+ positions_data[num_patches + ptr] = x + dx;
+ positions_data[num_patches * 2 + ptr] = y + dy;
+ positions_data[num_patches * 3 + ptr] = x + dx;
+ ptr++;
+ }
+ }
+ }
+ }
+
+ ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+ free(positions_data);
+ }
+ else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
- }
- {
- struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
- int* patches_data = (int*)malloc(ggml_nbytes(patches));
- for (int i = 0; i < num_patches; i++) {
- patches_data[i] = i + 1;
+ {
+ struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+ int* patches_data = (int*)malloc(ggml_nbytes(patches));
+ for (int i = 0; i < num_patches; i++) {
+ patches_data[i] = i + 1;
+ }
+ ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+ free(patches_data);
}
- ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
- free(patches_data);
}
}
return 3584;
}
}
+ if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+ return ctx->vision_model.mm_1_b->ne[0];
+ }
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}
return 0;
}
+
+bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+ return ctx->has_qwen2vl_merger;
+}
+
+
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
+ clip_image_f32 clip_img;
+ clip_img.buf.resize(h * w * 3);
+ for (int i = 0; i < h*w*3; i++)
+ {
+ clip_img.buf[i] = img[i];
+ }
+ clip_img.nx = w;
+ clip_img.ny = h;
+ clip_image_encode(ctx, n_threads, &clip_img, vec);
+ return true;
+}
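For reference, a minimal usage sketch (not part of the patch) of how the per-image helpers added above are meant to be combined when sizing an embedding buffer; the function name `alloc_image_embd` is an illustrative assumption:

```cpp
// sketch only: allocate an embedding buffer for a single image of arbitrary size,
// using the per-image helpers introduced above (clip_embd_nbytes_by_img / clip_n_patches_by_img)
#include "clip.h"
#include <cstdlib>

static float * alloc_image_embd(struct clip_ctx * ctx, int img_w, int img_h) {
    // clip_embd_nbytes_by_img() = clip_n_patches_by_img() * clip_n_mmproj_embd() * sizeof(float)
    const size_t nbytes = clip_embd_nbytes_by_img(ctx, /*img_h=*/img_h, /*img_w=*/img_w);
    return (float *) malloc(nbytes); // caller is responsible for free()
}
```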
CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
+CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
-CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
-CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
+CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
+CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
+CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
+CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
+CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+
+CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
#ifdef __cplusplus
}
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
- if (clip_is_minicpmv(ctx_clip)) {
+ if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
std::vector<float *> image_embd_v;
image_embd_v.resize(img_res_v.size);
struct clip_image_size * load_image_size = clip_image_size_init();
+
for (size_t i = 0; i < img_res_v.size; i++) {
const int64_t t_img_enc_step_start_us = ggml_time_us();
- image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip));
+ image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
int patch_size=14;
load_image_size->width = img_res_v.data[i].nx;
load_image_size->height = img_res_v.data[i].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
+
bool encoded = false;
- int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
- if (has_minicpmv_projector == 2) {
- encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
- }
- else if (has_minicpmv_projector == 3) {
+ if (clip_is_qwen2vl(ctx_clip)) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
+ else {
+ int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
+ if (has_minicpmv_projector == 2) {
+ encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
+ }
+ else if (has_minicpmv_projector == 3) {
+ encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
+ }
+ }
+
if (!encoded) {
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
return false;
int n_img_pos_out = 0;
for (size_t i = 0; i < image_embd_v.size(); i++) {
- std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip));
- n_img_pos_out += clip_n_patches(ctx_clip);
+ std::memcpy(
+ image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
+ image_embd_v[i],
+ clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
+ n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
}
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}
- float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
+ float * image_embd;
+ if (clip_is_qwen2vl(ctx_clip)) {
+ // qwen2vl does not split the image into chunks, so `num_max_patches` is not needed
+ image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny));
+ } else {
+ image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
+ }
if (!image_embd) {
LOG_ERR("Unable to allocate memory for image embeddings\n");
return false;
--- /dev/null
+import argparse
+import os
+from typing import Dict
+
+import torch
+import numpy as np
+from gguf import *
+from transformers import (
+ Qwen2VLForConditionalGeneration,
+ Qwen2VLProcessor,
+ AutoProcessor,
+ Qwen2VLConfig
+)
+
+
+VISION = "clip.vision"
+
+
+def k(raw_key: str, arch: str) -> str:
+ return raw_key.format(arch=arch)
+
+
+def to_gguf_name(name: str) -> str:
+ og = name
+ name = name.replace("text_model", "t").replace("vision_model", "v")
+ name = name.replace("blocks", "blk").replace("embeddings.", "")
+ name = name.replace("attn.", "attn_")
+ name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
+ # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
+ name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+ name = name.replace("merger.mlp", 'mm')
+ print(f"[to_gguf_name] {og} --> {name}")
+ return name
+
+
+def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
+ vision_model = qwen2vl.visual
+ tensor_map = {}
+ for name, ten in vision_model.state_dict().items():
+ ten = ten.numpy()
+ if 'qkv' in name:
+ if ten.ndim == 2: # weight
+ c3, _ = ten.shape
+ else: # bias
+ c3 = ten.shape[0]
+ assert c3 % 3 == 0
+ c = c3 // 3
+ wq = ten[:c]
+ wk = ten[c: c * 2]
+ wv = ten[c * 2:]
+ tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
+ tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
+ tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
+ elif 'merger' in name:
+ if name.endswith("ln_q.weight"):
+ tensor_map['v.post_ln.weight'] = ten
+ elif name.endswith("ln_q.bias"):
+ tensor_map['v.post_ln.bias'] = ten
+ else:
+ # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
+ tensor_map[to_gguf_name(name)] = ten
+ elif 'patch_embed.proj.weight' in name:
+ # NOTE: split Conv3D into Conv2Ds
+ c1, c2, kt, kh, kw = ten.shape
+ assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+ tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
+ tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
+ else:
+ tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
+
+ for new_name, ten in tensor_map.items():
+ if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+ tensor_map[new_name] = ten.astype(np.float32)
+ else:
+ tensor_map[new_name] = ten.astype(dtype)
+ tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
+ return tensor_map
+
+
+def main(args):
+ if args.data_type == 'fp32':
+ dtype = torch.float32
+ np_dtype = np.float32
+ ftype = 0
+ elif args.data_type == 'fp16':
+ dtype = torch.float32
+ np_dtype = np.float16
+ ftype = 1
+ else:
+ raise ValueError()
+
+ model_name = args.model_name
+ print("model_name: ", model_name)
+ qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_name, torch_dtype=dtype, device_map="cpu"
+ )
+ cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
+ vcfg = cfg.vision_config
+
+ if os.path.isdir(model_name):
+ if model_name.endswith(os.sep):
+ model_name = model_name[:-1]
+ model_name = os.path.basename(model_name)
+ fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
+
+ fout = GGUFWriter(path=fname_out, arch="clip")
+ fout.add_description("image encoder for Qwen2VL")
+
+ fout.add_file_type(ftype)
+ fout.add_bool("clip.has_text_encoder", False)
+ fout.add_bool("clip.has_vision_encoder", True)
+ fout.add_bool("clip.has_qwen2vl_merger", True)
+ fout.add_string("clip.projector_type", "qwen2vl_merger")
+
+ print(cfg.vision_config)
+ if 'silu' in cfg.vision_config.hidden_act.lower():
+ fout.add_bool("clip.use_silu", True)
+ fout.add_bool("clip.use_gelu", False)
+ elif 'gelu' in cfg.vision_config.hidden_act.lower():
+ fout.add_bool("clip.use_silu", False)
+ fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
+ else:
+ raise ValueError()
+
+ tensor_map = find_vision_tensors(qwen2vl, np_dtype)
+ for name, data in tensor_map.items():
+ fout.add_tensor(name, data)
+
+ fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
+ fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
+ fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
+ fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
+ fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
+ fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
+ fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
+ fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
+ fout.add_name(model_name)
+ """
+ HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
+ it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
+ """
+
+ processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
+ fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
+ fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
+
+ fout.write_header_to_file()
+ fout.write_kv_data_to_file()
+ fout.write_tensors_to_file()
+ fout.close()
+ print("save model as: ", fname_out)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+ parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
+ args = parser.parse_args()
+ main(args)
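As a usage sketch (the file name is an assumption, since this hunk does not show the new file's path), the script above could be run as `python qwen2_vl_surgery.py Qwen/Qwen2-VL-2B-Instruct --data_type fp16`: it loads the HF checkpoint, extracts the `visual` tower, and writes a `...-vision.gguf` file that can then be passed as `--mmproj` to `llama-qwen2vl-cli`.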
--- /dev/null
+#include "arg.h"
+#include "base64.hpp"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "clip.h"
+#include "llava.h"
+#include "llama.h"
+#include "ggml.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+#ifndef NDEBUG
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#endif
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+#include <algorithm>
+#include <iostream>
+#include <fstream>
+
+
+static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
+ int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
+ int n_embd = llama_n_embd(llama_get_model(ctx_llama));
+ const int patch_size = 14 * 2;
+ const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
+ const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
+ auto img_tokens = image_embed->n_image_pos;
+ // llama_pos mrope_pos[img_tokens * 4];
+ std::vector<llama_pos> mrope_pos;
+ mrope_pos.resize(img_tokens * 4);
+
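+ // M-RoPE positions for the image tokens, stored as 4 blocks of img_tokens entries:
+ // block 0 = temporal (constant), block 1 = height (st_pos_id + y), block 2 = width (st_pos_id + x), block 3 = unused (0)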
+ for (int y = 0; y < ph; y++)
+ {
+ for (int x = 0; x < pw; x++)
+ {
+ int i = y * pw + x;
+ mrope_pos[i] = *st_pos_id;
+ mrope_pos[i + img_tokens] = *st_pos_id + y;
+ mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
+ mrope_pos[i + img_tokens * 3] = 0;
+ }
+ }
+ *st_pos_id += std::max(pw, ph);
+
+ int processed = 0;
+ std::vector<llama_pos> batch_mrope_pos;
+ batch_mrope_pos.resize(img_tokens * 4);
+
+ for (int i = 0; i < img_tokens; i += n_batch) {
+ int n_eval = img_tokens - i;
+ if (n_eval > n_batch) {
+ n_eval = n_batch;
+ }
+
+ // llama_pos batch_mrope_pos[n_eval * 4];
+ std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
+ memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
+ memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
+ memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
+ memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
+
+ llama_batch batch = {
+ int32_t(n_eval), // n_tokens
+ nullptr, // token
+ (image_embed->embed+i*n_embd), // embed
+ batch_mrope_pos.data(), // pos
+ nullptr, // n_seq_id
+ nullptr, // seq_id
+ nullptr, // logits
+ };
+
+ if (llama_decode(ctx_llama, batch)) {
+ LOG_ERR("%s : failed to eval\n", __func__);
+ return false;
+ }
+ *n_past += n_eval;
+ processed += n_eval;
+ }
+ return true;
+}
+
+
+static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
+ int N = (int) tokens.size();
+ std::vector<llama_pos> pos;
+ for (int i = 0; i < N; i += n_batch) {
+ int n_eval = (int) tokens.size() - i;
+ if (n_eval > n_batch) {
+ n_eval = n_batch;
+ }
+ auto batch = llama_batch_get_one(&tokens[i], n_eval);
+ // TODO: add mrope pos ids somewhere else
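+ // text tokens get the same sequential position in the t/h/w streams; the 4th stream stays 0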
+ pos.resize(batch.n_tokens * 4);
+ std::fill(pos.begin(), pos.end(), 0);
+ for (int j = 0; j < batch.n_tokens * 3; j ++) {
+ pos[j] = *st_pos_id + (j % batch.n_tokens);
+ }
+ batch.pos = pos.data();
+
+ if (llama_decode(ctx_llama, batch)) {
+ LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
+ return false;
+ }
+ *n_past += n_eval;
+ *st_pos_id += n_eval;
+ }
+ return true;
+}
+
+static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
+ std::vector<llama_token> tokens;
+ tokens.push_back(id);
+ return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
+}
+
+static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
+ std::string str2 = str;
+ std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
+ eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
+ return true;
+}
+
+static const char * sample(struct common_sampler * smpl,
+ struct llama_context * ctx_llama,
+ int * n_past, int * st_pos_id) {
+ const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
+ common_sampler_accept(smpl, id, true);
+ static std::string ret;
+ if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
+ ret = "</s>";
+ } else {
+ ret = common_token_to_piece(ctx_llama, id);
+ }
+ eval_id(ctx_llama, id, n_past, st_pos_id);
+ return ret.c_str();
+}
+
+static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
+static const char* IMG_BASE64_TAG_END = "\">";
+
+static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
+ begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
+ end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
+}
+
+static bool prompt_contains_image(const std::string& prompt) {
+ size_t begin, end;
+ find_image_tag_in_prompt(prompt, begin, end);
+ return (begin != std::string::npos);
+}
+
+// replaces the base64 image tag in the prompt with `replacement`
+static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
+ size_t img_base64_str_start, img_base64_str_end;
+ find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
+ if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
+ LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+ return NULL;
+ }
+
+ auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
+ auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
+ auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
+
+ auto required_bytes = base64::required_encode_size(base64_str.size());
+ auto img_bytes = std::vector<unsigned char>(required_bytes);
+ base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
+
+ auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
+ if (!embed) {
+ LOG_ERR("%s: could not load image from base64 string.\n", __func__);
+ return NULL;
+ }
+
+ return embed;
+}
+
+static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
+ size_t begin, end;
+ find_image_tag_in_prompt(prompt, begin, end);
+ if (begin == std::string::npos || end == std::string::npos) {
+ return prompt;
+ }
+ auto pre = prompt.substr(0, begin);
+ auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
+ return pre + replacement + post;
+}
+
+struct llava_context {
+ struct clip_ctx * ctx_clip = NULL;
+ struct llama_context * ctx_llama = NULL;
+ struct llama_model * model = NULL;
+};
+
+static void print_usage(int, char ** argv) {
+ LOG("\n example usage:\n");
+ LOG("\n %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+ LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
+}
+
+static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
+
+ // load and preprocess the image
+ llava_image_embed * embed = NULL;
+ auto prompt = params->prompt;
+ if (prompt_contains_image(prompt)) {
+ if (!params->image.empty()) {
+ LOG_INF("using base64 encoded image instead of command line image path\n");
+ }
+ embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
+ if (!embed) {
+ LOG_ERR("%s: can't load image from prompt\n", __func__);
+ return NULL;
+ }
+ params->prompt = remove_image_from_prompt(prompt);
+ } else {
+ embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
+ if (!embed) {
+ fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
+ return NULL;
+ }
+ }
+
+ return embed;
+}
+
+static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
+ int n_past = 0;
+ int cur_pos_id = 0;
+
+ const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
+
+ std::string system_prompt, user_prompt;
+ size_t image_pos = prompt.find("<|vision_start|>");
+ if (image_pos != std::string::npos) {
+ // new templating mode: provide the full prompt including the system message and use <|vision_start|> as a placeholder for the image
+ system_prompt = prompt.substr(0, image_pos);
+ user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
+ LOG_INF("system_prompt: %s\n", system_prompt.c_str());
+ if (params->verbose_prompt) {
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
+ for (int i = 0; i < (int) tmp.size(); i++) {
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+ }
+ }
+ LOG_INF("user_prompt: %s\n", user_prompt.c_str());
+ if (params->verbose_prompt) {
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
+ for (int i = 0; i < (int) tmp.size(); i++) {
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+ }
+ }
+ } else {
+ // default mode: wrap the user prompt in the Qwen2-VL chat template
+ system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
+ user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
+ if (params->verbose_prompt) {
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
+ for (int i = 0; i < (int) tmp.size(); i++) {
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
+ }
+ }
+ }
+
+ eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
+ if (image_embed != nullptr) {
+ auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
+ qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
+ }
+ eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
+
+ // generate the response
+
+ LOG("\n");
+
+ struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
+ if (!smpl) {
+ LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
+ exit(1);
+ }
+
+ std::string response = "";
+ for (int i = 0; i < max_tgt_len; i++) {
+ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
+ response += tmp;
+ if (strcmp(tmp, "</s>") == 0) break;
+ if (strstr(tmp, "###")) break; // Yi-VL behavior
+ LOG("%s", tmp);
+ if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
+ if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
+ if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
+
+ fflush(stdout);
+ }
+
+ common_sampler_free(smpl);
+ LOG("\n");
+}
+
+static struct llama_model * llava_init(common_params * params) {
+ llama_backend_init();
+ llama_numa_init(params->numa);
+
+ llama_model_params model_params = common_model_params_to_llama(*params);
+
+ llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
+ if (model == NULL) {
+ LOG_ERR("%s: unable to load model\n" , __func__);
+ return NULL;
+ }
+ return model;
+}
+
+static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
+ const char * clip_path = params->mmproj.c_str();
+
+ auto prompt = params->prompt;
+ if (prompt.empty()) {
+ prompt = "describe the image in detail.";
+ }
+
+ auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
+
+ llama_context_params ctx_params = common_context_params_to_llama(*params);
+ ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+
+ llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
+
+ if (ctx_llama == NULL) {
+ LOG_ERR("%s: failed to create the llama_context\n" , __func__);
+ return NULL;
+ }
+
+ auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+ ctx_llava->ctx_llama = ctx_llama;
+ ctx_llava->ctx_clip = ctx_clip;
+ ctx_llava->model = model;
+ return ctx_llava;
+}
+
+static void llava_free(struct llava_context * ctx_llava) {
+ if (ctx_llava->ctx_clip) {
+ clip_free(ctx_llava->ctx_clip);
+ ctx_llava->ctx_clip = NULL;
+ }
+
+ llama_free(ctx_llava->ctx_llama);
+ llama_free_model(ctx_llava->model);
+ llama_backend_free();
+}
+
+#ifndef NDEBUG
+
+static void debug_test_mrope_2d() {
+ // 1. Initialize backend
+ ggml_backend_t backend = NULL;
+ std::string backend_name = "";
+#ifdef GGML_USE_CUDA
+ fprintf(stderr, "%s: using CUDA backend\n", __func__);
+ backend = ggml_backend_cuda_init(0); // init device 0
+ backend_name = "cuda";
+ if (!backend) {
+ fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+ }
+#endif
+ // if there is no GPU backend, fall back to the CPU backend
+ if (!backend) {
+ backend = ggml_backend_cpu_init();
+ backend_name = "cpu";
+ }
+
+ // Calculate the size needed to allocate
+ size_t ctx_size = 0;
+ ctx_size += 2 * ggml_tensor_overhead(); // tensors
+ // no need to allocate anything else!
+
+ // 2. Allocate `ggml_context` to store tensor data
+ struct ggml_init_params params = {
+ /*.mem_size =*/ ctx_size,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
+ };
+ struct ggml_context * ctx = ggml_init(params);
+
+ struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
+
+ struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
+ ggml_set_name(pos, "pos");
+ ggml_set_input(pos);
+
+ std::vector<float> dummy_q;
+ dummy_q.resize(128 * 12 * 30);
+ std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
+ // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
+
+ std::vector<int> pos_id;
+ pos_id.resize(30 * 4);
+ for (int i = 0; i < 30; i ++) {
+ pos_id[i] = i;
+ pos_id[i + 30] = i + 10;
+ pos_id[i + 60] = i + 20;
+ pos_id[i + 90] = i + 30;
+ }
+ int sections[4] = {32, 32, 0, 0};
+
+ // 4. Allocate a `ggml_backend_buffer` to store all tensors
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+ // 5. Copy tensor data from main memory (RAM) to backend buffer
+ ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
+ ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
+
+ // 6. Create a `ggml_cgraph` for mul_mat operation
+ struct ggml_cgraph * gf = NULL;
+ struct ggml_context * ctx_cgraph = NULL;
+
+ // create a temporary context to build the graph
+ struct ggml_init_params params0 = {
+ /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+ };
+ ctx_cgraph = ggml_init(params0);
+ gf = ggml_new_graph(ctx_cgraph);
+
+ struct ggml_tensor * result0 = ggml_rope_multi(
+ ctx_cgraph, inp_raw, pos, nullptr,
+ 128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
+ 0, 1, 32, 1);
+
+ // Add "result" tensor and all of its dependencies to the cgraph
+ ggml_build_forward_expand(gf, result0);
+
+ // 7. Create a `ggml_gallocr` for cgraph computation
+ ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+ ggml_gallocr_alloc_graph(allocr, gf);
+
+ // 9. Run the computation
+ int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
+ if (ggml_backend_is_cpu(backend)) {
+ ggml_backend_cpu_set_n_threads(backend, n_threads);
+ }
+ ggml_backend_graph_compute(backend, gf);
+
+ // 10. Retrieve results (output tensors)
+ // in this example, output tensor is always the last tensor in the graph
+ struct ggml_tensor * result = result0;
+ // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
+ float * result_data = (float *)malloc(ggml_nbytes(result));
+ // because the tensor data is stored in device buffer, we need to copy it back to RAM
+ ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
+ const std::string bin_file = "mrope_2d_" + backend_name +".bin";
+ std::ofstream outFile(bin_file, std::ios::binary);
+
+ if (outFile.is_open()) {
+ outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
+ outFile.close();
+ std::cout << "Data successfully written to " + bin_file << std::endl;
+ } else {
+ std::cerr << "Error opening file!" << std::endl;
+ }
+
+ free(result_data);
+ // 11. Free memory and exit
+ ggml_free(ctx_cgraph);
+ ggml_gallocr_free(allocr);
+ ggml_free(ctx);
+ ggml_backend_buffer_free(buffer);
+ ggml_backend_free(backend);
+}
+
+static void debug_dump_img_embed(struct llava_context * ctx_llava) {
+ int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
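+ // the 56x56 test image below yields (56/28)^2 = 4 merged patches, hence 4 embedding vectors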
+ int ne = n_embd * 4;
+ float vals[56 * 56 * 3];
+ // float embd[ne];
+ std::vector<float> embd;
+ embd.resize(ne);
+
+ for (int i = 0; i < 56*56; i++)
+ {
+ for (int c = 0; c < 3; c++)
+ vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+ }
+
+ clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
+
+ std::ofstream outFile("img_embed.bin", std::ios::binary);
+ if (outFile.is_open()) {
+ outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
+
+ outFile.close();
+ std::cout << "Data successfully written to mrope.bin" << std::endl;
+ } else {
+ std::cerr << "Error opening file!" << std::endl;
+ }
+}
+
+#endif
+
+
+int main(int argc, char ** argv) {
+ ggml_time_init();
+
+ common_params params;
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
+ return 1;
+ }
+
+ common_init();
+
+ if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
+ print_usage(argc, argv);
+ return 1;
+ }
+
+ auto * model = llava_init(¶ms);
+ if (model == NULL) {
+ fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
+ return 1;
+ }
+
+ if (prompt_contains_image(params.prompt)) {
+ auto * ctx_llava = llava_init_context(¶ms, model);
+
+ auto * image_embed = load_image(ctx_llava, ¶ms, "");
+
+ // process the prompt
+ process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
+
+ llama_perf_context_print(ctx_llava->ctx_llama);
+ llava_image_embed_free(image_embed);
+ ctx_llava->model = NULL;
+ llava_free(ctx_llava);
+#ifndef NDEBUG
+ } else if (params.image[0].empty()) {
+ auto ctx_llava = llava_init_context(¶ms, model);
+
+ debug_test_mrope_2d();
+ debug_dump_img_embed(ctx_llava);
+
+ llama_perf_context_print(ctx_llava->ctx_llama);
+ ctx_llava->model = NULL;
+ llava_free(ctx_llava);
+#endif
+ } else {
+ for (auto & image : params.image) {
+ auto * ctx_llava = llava_init_context(¶ms, model);
+
+ auto * image_embed = load_image(ctx_llava, ¶ms, image);
+ if (!image_embed) {
+ LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
+ return 1;
+ }
+
+ // process the prompt
+ process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
+
+ llama_perf_context_print(ctx_llava->ctx_llama);
+ llava_image_embed_free(image_embed);
+ ctx_llava->model = NULL;
+ llava_free(ctx_llava);
+ }
+ }
+
+ llama_free_model(model);
+
+ return 0;
+}
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1
-#define GGML_ROPE_TYPE_NEOX 2
+#define GGML_ROPE_TYPE_NEOX 2
+#define GGML_ROPE_TYPE_MROPE 8
+#define GGML_ROPE_TYPE_VISION 24
#define GGUF_MAGIC "GGUF"
float beta_fast,
float beta_slow);
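+ // multimodal rotary position embedding (M-RoPE): like ggml_rope_ext(), but `b` carries
+ // four position values per token (4 consecutive blocks of n_tokens entries) and
+ // `sections` splits the rotary dims across the (t, h, w, extra) streams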
+ GGML_API struct ggml_tensor * ggml_rope_multi(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int sections[4],
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
if (*ext_factor != 0) {
return false;
}
+
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+
return true;
}
case GGML_OP_UPSCALE: {
}
}
+static void ggml_mrope_cache_init(
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+ float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta_t = theta_base_t;
+ float theta_h = theta_base_h;
+ float theta_w = theta_base_w;
+ float theta_e = theta_base_e; // extra position id for vision encoder
+ int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+ int sec_w = sections[1] + sections[0];
+ int sec_e = sections[2] + sec_w;
+ GGML_ASSERT(sect_dims <= ne0);
+
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+
+ int sector = (i0 / 2) % sect_dims;
+ if (indep_sects) {
+ // compute theta independently for each dim section
+ // (i.e. reset the corresponding theta when `i0` crosses from one section into another)
+ if (sector == 0) {
+ theta_t = theta_base_t;
+ }
+ else if (sector == sections[0]) {
+ theta_h = theta_base_h;
+ }
+ else if (sector == sec_w) {
+ theta_w = theta_base_w;
+ }
+ else if (sector == sec_e) {
+ theta_e = theta_base_e;
+ }
+ }
+
+ float theta = theta_t;
+ if (sector >= sections[0] && sector < sec_w) {
+ theta = theta_h;
+ }
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
+ theta = theta_w;
+ }
+ else if (sector >= sec_w + sections[2]) {
+ theta = theta_e;
+ }
+
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;
+
+ theta_t *= theta_scale;
+ theta_w *= theta_scale;
+ theta_h *= theta_scale;
+ theta_e *= theta_scale;
+ }
+}
+
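As a small illustrative sketch (not part of the patch), the sector logic above maps each rotary dim pair to one of the four position streams; the `{16, 24, 24, 0}` split below is only an assumed Qwen2-VL-style example:

```cpp
// sketch: which position stream (0 = temporal, 1 = height, 2 = width, 3 = extra)
// a given rotary dim pair uses, mirroring the sector logic in ggml_mrope_cache_init
static int mrope_stream_for_pair(int pair, const int sections[4]) {
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sector    = pair % sect_dims;
    if (sector < sections[0])                             return 0; // temporal
    if (sector < sections[0] + sections[1])               return 1; // height
    if (sector < sections[0] + sections[1] + sections[2]) return 2; // width
    return 3;                                             // extra (vision only)
}
// e.g. with sections = {16, 24, 24, 0} and a 128-dim head (64 pairs):
// pairs 0..15 -> temporal, 16..39 -> height, 40..63 -> width
```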
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ if (is_mrope) {
+ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+ }
+
+ if (is_vision) {
+ GGML_ASSERT(n_dims == ne0/2);
+ }
const float * freq_factors = NULL;
if (src2 != NULL) {
const int32_t * pos = (const int32_t *) src1->data;
- for (int64_t i3 = 0; i3 < ne3; i3++) {
- for (int64_t i2 = 0; i2 < ne2; i2++) {
- const int64_t p = pos[i2];
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ if (!is_mrope) {
+ const int64_t p = pos[i2];
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+ else {
+ const int64_t p_t = pos[i2];
+ const int64_t p_h = pos[i2 + ne2];
+ const int64_t p_w = pos[i2 + ne2 * 2];
+ const int64_t p_e = pos[i2 + ne2 * 3];
+ ggml_mrope_cache_init(
+ p_t, p_h, p_w, p_e, sections, is_vision,
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
- for (int64_t i1 = 0; i1 < ne1; i1++) {
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;
- if (!is_neox) {
+ if (is_neox || is_mrope) {
+ if (is_vision){
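+ // vision (mrope) mode: rotate pairs (i, i + n_dims), i.e. across the two halves of the head (n_dims == ne0/2)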
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+ }
+ } else {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+ } else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ }
+
+ if (is_vision) {
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
- const float x1 = src[n_dims/2];
+ const float x1 = src[n_dims];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
- }
-
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ } else {
+ // fill the remaining channels with data from the src tensor
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
}
}
}
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
GGML_TENSOR_UNARY_OP_LOCALS
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ if (is_mrope) {
+ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+ }
+
+ if (is_vision) {
+ GGML_ASSERT(n_dims == ne0/2);
+ }
const float * freq_factors = NULL;
if (src2 != NULL) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
- const int64_t p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ if (!is_mrope) {
+ const int64_t p = pos[i2];
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+ else {
+ const int64_t p_t = pos[i2];
+ const int64_t p_h = pos[i2 + ne2];
+ const int64_t p_w = pos[i2 + ne2 * 2];
+ const int64_t p_e = pos[i2 + ne2 * 3];
+ ggml_mrope_cache_init(
+ p_t, p_h, p_w, p_e, sections, is_vision,
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
- if (!is_neox) {
+ if (is_neox || is_mrope) {
+ if (is_vision) {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ }
+ } else {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ }
+ }
+ } else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
- } else {
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ }
+
+ if (is_vision) {
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
- }
-
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ } else {
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
}
}
}
float v[2];
};
+
+struct mrope_sections {
+ int v[4];
+};
+
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
+template<typename T, bool has_ff>
+static __global__ void rope_multi(
+ const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0/2;
+ const int i2 = row/p_delta_rows;
+
+ int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+ int sec_w = sections.v[1] + sections.v[0];
+ int sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (sector < sections.v[0]) {
+ theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sections.v[0] && sector < sec_w) {
+ theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+ theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+ }
+ else if (sector >= sec_w + sections.v[2]) {
+ theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+ }
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + n_dims/2];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_vision(
+ const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const int i = row*ne0 + i0/2;
+ const int i2 = row/p_delta_rows; // index of the current token
+
+ int sect_dims = sections.v[0] + sections.v[1];
+ int sec_w = sections.v[1] + sections.v[0];
+ int sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (sector < sections.v[0]) {
+ const int p = sector;
+ theta_base = pos[i2]*powf(theta_scale, p);
+ }
+ else if (sector >= sections.v[0] && sector < sec_w) {
+ const int p = sector - sections.v[0];
+ theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+ }
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + n_dims];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+}
+
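
A note on how the two kernels above consume the four position streams: the caller packs the positions stream-major, so `pos[i2]`, `pos[i2 + ne2]`, `pos[i2 + ne2*2]` and `pos[i2 + ne2*3]` are the temporal, height, width and (unused) extra ids of token `i2`, and `sections` decides which of those streams drives a given rotation pair. The sketch below is not part of the patch; it just replays the sector logic of `rope_multi` on the host. The section split used is only an example (it is shaped like the {16, 24, 24} split Qwen2-VL models use for 128-dim heads, but treat that as an assumption here).

```cpp
// Standalone host-side sketch, not part of the patch: report which position
// stream (t, h, w, e) drives each rotation pair, mirroring rope_multi's sector
// logic. The section split below is an example, not read from any model.
#include <cstdio>

int main() {
    const int sections[4] = {16, 24, 24, 0};
    const int sect_dims   = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w       = sections[0] + sections[1];

    const int n_dims = 128; // head_dim for this example
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const int sector = (i0/2) % sect_dims;
        const char * stream =
            sector < sections[0]         ? "t" :
            sector < sec_w               ? "h" :
            sector < sec_w + sections[2] ? "w" : "e";
        printf("pair %2d uses position stream %s\n", i0/2, stream);
    }
    return 0;
}
```
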
template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
}
}
+template<typename T>
+static void rope_multi_cuda(
+ const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors, sections
+ );
+ } else {
+ rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors, sections
+ );
+ }
+}
+
+template<typename T>
+static void rope_vision_cuda(
+ const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+ // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
+ // where x ~= ceil(head_dim / (2*CUDA_ROPE_BLOCK_SIZE)), since each thread handles one channel pair
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors, sections
+ );
+ } else {
+ rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors, sections
+ );
+ }
+}
+
static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
+static void rope_multi_cuda_f16(
+ const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+ rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_multi_cuda_f32(
+ const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+ rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f16(
+ const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+ rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f32(
+ const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+ rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
+ const int64_t ne00 = src0->ne[0]; // head dims
+ const int64_t ne01 = src0->ne[1]; // num heads
+ const int64_t ne02 = src0->ne[2]; // num tokens
const int64_t nr = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+ mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ if (is_mrope) {
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+ }
+
+ if (is_vision) {
+ GGML_ASSERT(n_dims == ne00/2);
+ }
const int32_t * pos = (const int32_t *) src1_d;
} else {
GGML_ABORT("fatal error");
}
+ } else if (is_mrope && !is_vision) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_multi_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, sections, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_multi_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, sections, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ } else if (is_vision) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_vision_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, sections, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_vision_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, sections, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
- case GGML_OP_ROPE:
return true;
+ case GGML_OP_ROPE:
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return true;
+ }
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
case GGML_OP_ARGMAX:
case GGML_OP_NORM:
- case GGML_OP_ROPE:
return true;
+ case GGML_OP_ROPE:
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return true;
+ }
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
} break;
case GGML_OP_ROPE:
{
- GGML_ASSERT(ne10 == ne02);
+ // make sure we have one or more position ids (ne10) per token (ne02)
+ GGML_ASSERT(ne10 % ne02 == 0);
+ GGML_ASSERT(ne10 >= ne02);
const int nth = MIN(1024, ne00);
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return ggml_is_contiguous(op->src[0]);
+ }
case GGML_OP_IM2COL:
// TODO: add support for the new F32 operations
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return ggml_is_contiguous(op->src[0]);
+ }
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
+ int sections[4] = {0, 0, 0, 0};
+
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+ int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
+ memcpy(params + 11, &sections, sizeof(int)*4);
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
);
}
+struct ggml_tensor * ggml_rope_multi(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int sections[4],
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ // Multimodal Rotary Position Embedding
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
+ GGML_ASSERT(ggml_is_vector(b));
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expects 4 position ids per token
+
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+ memcpy(params + 5, &freq_base, sizeof(float));
+ memcpy(params + 6, &freq_scale, sizeof(float));
+ memcpy(params + 7, &ext_factor, sizeof(float));
+ memcpy(params + 8, &attn_factor, sizeof(float));
+ memcpy(params + 9, &beta_fast, sizeof(float));
+ memcpy(params + 10, &beta_slow, sizeof(float));
+ memcpy(&params[11], sections, sizeof(int)*4);
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_ROPE;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
+
+ return result;
+}
+
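
To make the new API concrete, here is a minimal CPU-only usage sketch. It is not taken from the patch: the tensor shapes, section split and frequency parameters are examples, and it assumes the `ggml_rope_multi()` signature added above plus the existing `ggml_graph_compute_with_ctx()` helper (the same one used by the rope test further down).

```cpp
// Minimal CPU-only sketch (not part of the patch) showing how ggml_rope_multi
// could be called: a [head_dim, n_head, n_tokens] activation plus an I32
// position tensor holding 4*n_tokens ids, laid out stream-major (all t ids,
// then all h, then all w, then the unused fourth stream).
#include "ggml.h"

int main() {
    struct ggml_init_params ip = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int head_dim = 128, n_head = 12, n_tokens = 6;

    struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens*4);

    // t/h/w ids for each token; the fourth stream is left at 0
    int32_t * p = (int32_t *) pos->data;
    for (int i = 0; i < n_tokens; ++i) {
        p[i]              = i; // temporal
        p[i + n_tokens]   = 0; // height
        p[i + n_tokens*2] = 0; // width
        p[i + n_tokens*3] = 0; // unused
    }

    int sections[4] = {16, 24, 24, 0}; // example section split (3 sections padded to 4)

    struct ggml_tensor * out = ggml_rope_multi(
        ctx, x, pos, /*freq_factors*/ NULL,
        /*n_dims*/ head_dim, sections, GGML_ROPE_TYPE_MROPE,
        /*n_ctx_orig*/ 32768, /*freq_base*/ 1000000.0f, /*freq_scale*/ 1.0f,
        /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f, /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    ggml_free(ctx);
    return 0;
}
```

With `GGML_ROPE_TYPE_VISION` the same call is expected to receive `n_dims == ne0/2` and only the first two sections, as exercised by the test cases later in this patch.
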
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
+ DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
QWEN = auto()
QWEN2 = auto()
QWEN2MOE = auto()
+ QWEN2VL = auto()
PHI2 = auto()
PHI3 = auto()
PLAMO = auto()
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.QWEN2MOE: "qwen2moe",
+ MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PLAMO: "plamo",
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.QWEN2VL: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
MODEL_ARCH.QWEN2MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
+ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
+ self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
+
def add_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
};
enum llama_rope_type {
- LLAMA_ROPE_TYPE_NONE = -1,
- LLAMA_ROPE_TYPE_NORM = 0,
- LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
+ LLAMA_ROPE_TYPE_NONE = -1,
+ LLAMA_ROPE_TYPE_NORM = 0,
+ LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
+ LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
+ LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
};
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
LLM_ARCH_QWEN,
LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE,
+ LLM_ARCH_QWEN2VL,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PLAMO,
{ LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
+ { LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PLAMO, "plamo" },
LLM_KV_ATTENTION_SCALE,
LLM_KV_ROPE_DIMENSION_COUNT,
+ LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
+ { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_QWEN2VL,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_QWEN2MOE,
{
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
- float rope_attn_factor = 1.0f;
- float rope_freq_base_train;
- float rope_freq_scale_train;
- uint32_t n_ctx_orig_yarn;
- float rope_yarn_log_mul;
+ float rope_attn_factor = 1.0f;
+ float rope_freq_base_train;
+ float rope_freq_scale_train;
+ uint32_t n_ctx_orig_yarn;
+ float rope_yarn_log_mul;
+ int rope_sections[4];
// for State Space Models
uint32_t ssm_d_conv = 0;
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
+ if (!std::equal(std::begin(this->rope_sections),
+ std::end(this->rope_sections),
+ std::begin(other.rope_sections))) return true;
if (this->ssm_d_conv != other.ssm_d_conv) return true;
if (this->ssm_d_inner != other.ssm_d_inner) return true;
// whether we are computing encoder output or decoder output
bool is_encoding = false;
+ // TODO: find a better way to accommodate multi-dimensional position encoding methods
+ // number of position ids each token gets, 1 per token in most cases.
+ // with M-RoPE each token gets 4 position ids: temporal, height and width coordinates plus one unused section.
+ int n_pos_per_token = 1;
+
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_QWEN2VL:
+ {
+ std::array<int, 4> section_dims;
+ ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, section_dims, 4, true);
+ std::copy(section_dims.begin(), section_dims.begin() + 4, std::begin(hparams.rope_sections));
+ }
+ // fall through
case LLM_ARCH_QWEN2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
}
} break;
case LLM_ARCH_QWEN2:
+ case LLM_ARCH_QWEN2VL:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
return gf;
}
+ struct ggml_cgraph * build_qwen2vl() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
+ cb(lctx.inp_pos, "inp_pos", -1);
+ ggml_set_input(lctx.inp_pos);
+ struct ggml_tensor * inp_pos = lctx.inp_pos;
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ int sections[4];
+ std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_rope_multi(
+ ctx0,
+ ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_multi(
+ ctx0,
+ ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+ n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
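
`build_qwen2vl()` above allocates `inp_pos` with `n_tokens * 4` entries, and the input-setting change further down copies `ubatch.pos` verbatim, so the caller has to supply the position ids stream-major: all temporal ids first, then all height ids, then all width ids, then the unused fourth stream. The snippet below is only a schematic of that layout for a text-only batch (where, under Qwen2-VL's M-RoPE scheme, the three streams typically carry the same sequential index); the actual per-patch image coordinates are computed by the `llama-qwen2vl-cli` example and are not reproduced here.

```cpp
// Schematic only, not from the patch: the stream-major position layout that the
// Qwen2-VL graph reads back as pos[i], pos[i + n_tokens], pos[i + 2*n_tokens],
// pos[i + 3*n_tokens]. Values shown are for a hypothetical 4-token text batch.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4;
    std::vector<int32_t> pos(4 * n_tokens, 0);

    for (int i = 0; i < n_tokens; ++i) {
        pos[i]                = i; // temporal stream
        pos[i +     n_tokens] = i; // height stream (same as temporal for plain text)
        pos[i + 2 * n_tokens] = i; // width stream  (same as temporal for plain text)
        pos[i + 3 * n_tokens] = 0; // fourth stream, unused
    }

    const char * names[4] = {"t", "h", "w", "e"};
    for (int s = 0; s < 4; ++s) {
        printf("%s:", names[s]);
        for (int i = 0; i < n_tokens; ++i) {
            printf(" %d", (int) pos[i + s * n_tokens]);
        }
        printf("\n");
    }
    return 0;
}
```
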
struct ggml_cgraph * build_qwen2moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
{
result = llm.build_qwen2();
} break;
+ case LLM_ARCH_QWEN2VL:
+ {
+ lctx.n_pos_per_token = 4;
+ result = llm.build_qwen2vl();
+ } break;
case LLM_ARCH_QWEN2MOE:
{
result = llm.build_qwen2moe();
if (ubatch.pos && lctx.inp_pos) {
const int64_t n_tokens = ubatch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+ auto n_pos = lctx.n_pos_per_token;
+ ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
}
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
case LLM_ARCH_MINICPM3:
return LLAMA_ROPE_TYPE_NEOX;
+ case LLM_ARCH_QWEN2VL:
+ return LLAMA_ROPE_TYPE_MROPE;
+
// all model arches should be listed explicitly here
case LLM_ARCH_UNKNOWN:
GGML_ABORT("unknown architecture");
ggml_set_name(a, "a");
}
- ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ ggml_tensor * pos;
+ if (is_mrope || is_vision) {
+ pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+ } else {
+ pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+ }
ggml_set_name(pos, "pos");
ggml_tensor * freq = nullptr;
ggml_set_name(freq, "freq");
}
- ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ ggml_tensor * out;
+ if (is_mrope) {
+ if (is_vision) {
+ GGML_ASSERT(n_dims/4 > 0);
+ int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only uses the first two sections for the image (x, y) coordinates
+ out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ } else {
+ GGML_ASSERT(n_dims/3 > 0);
+ int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+ out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
+ } else {
+ out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ }
ggml_set_name(out, "out");
return out;
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// pos
- std::vector<int> data(ne_a[2]);
- for (int i = 0; i < ne_a[2]; i++) {
+ const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+ std::vector<int> data(num_pos_ids);
+ for (int i = 0; i < num_pos_ids; i++) {
data[i] = rand() % n_ctx;
}
- ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
+ ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
} else {
if (t->ne[0] == n_dims/2) {
// frequency factors in the range [0.9f, 1.1f]
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
}
+ if (all) {
+ test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
+ test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
+ test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
+ }
+
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
}
}
struct ggml_tensor * x;
// rope f32
- for (int m = 0; m < 3; ++m) {
+ for (int m = 0; m < 5; ++m) {
const int ndims = 4;
const int64_t n_rot = 128;
const int n_past_0 = 100;
const int n_past_2 = 33;
- struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
- struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
- struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
-
- for (int i = 0; i < ne[2]; ++i) {
- ((int32_t *) p0->data)[i] = n_past_0 + i;
- ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
- ((int32_t *) p2->data)[i] = n_past_2 + i;
- }
-
- // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
- const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
-
+ struct ggml_tensor * r0;
+ struct ggml_tensor * r1;
+ struct ggml_tensor * r2;
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ int mode = -1;
- // 100, 101, 102, ..., 172
- struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
- // -67, -67, -67, ..., -67
- struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+ if (m < 3) {
+ struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+ struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+ struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
- // 33, 34, 35, ..., 105
- struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
+ for (int i = 0; i < ne[2]; ++i) {
+ ((int32_t *) p0->data)[i] = n_past_0 + i;
+ ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
+ ((int32_t *) p2->data)[i] = n_past_2 + i;
+ }
+ // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
+ mode = m == 0 ? 0 : m == 1 ? 2 : 4;
+
+ // 100, 101, 102, ..., 172
+ r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
+ // -67, -67, -67, ..., -67
+ r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+
+ // 33, 34, 35, ..., 105
+ r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
+ } else {
+ // test the multi-dimensional rope position embedding modes (m-rope, vision)
+ struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+ struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+ struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+
+ int sections[4] = {16, 24, 24, 0};
+ mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
+
+ for (int i = 0; i < ne[2]; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ ((int32_t *) p0->data)[i + ne[2] * j] = n_past_0 + i + j;
+ ((int32_t *) p1->data)[i + ne[2] * j] = n_past_2 - n_past_0;
+ ((int32_t *) p2->data)[i + ne[2] * j] = n_past_2 + i + j;
+ }
+ }
+
+ // [[100, 101, 102, ..., 172],
+ // [101, 102, 103, ..., 173],
+ // [102, 103, 104, ..., 174]]
+ r0 = ggml_rope_multi(
+ ctx0, x, p0, nullptr,
+ n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+ // [[-67, -67, -67, ..., -67]
+ // [-67, -67, -67, ..., -67]
+ // [-67, -67, -67, ..., -67]]
+ r1 = ggml_rope_multi(
+ ctx0, r0, p1, nullptr,
+ n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+
+ // [[33, 34, 35, ..., 105]
+ // [34, 35, 36, ..., 106]
+ // [35, 36, 37, ..., 107]]
+ r2 = ggml_rope_multi(
+ ctx0, x, p2, nullptr,
+ n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+ }
ggml_cgraph * gf = ggml_new_graph(ctx0);