if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 &&
+ op->src[0]->ne[0] != 72 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
op->src[0]->ne[0] != 112 &&
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
}
- for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
- for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
+ for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
+ for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
#include "clip-impl.h"
#include "ggml.h"
#include "ggml-cpp.h"
-#include "ggml-cpu.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "gguf.h"
#include <cstring>
#include <fstream>
#include <map>
-#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
-#include <sstream>
#include <cinttypes>
#include <limits>
#include <array>
-#include <numeric>
#include <functional>
+// TODO: allow passing a callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
enum ffn_op_type {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
+ clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
+ flash_attn_type = ctx_params.flash_attn_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
- ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
- v = ggml_cont(ctx0, v);
- //cb(k, "v", il);
-
ggml_tensor * cur;
- // TODO @ngxson : support flash attention
- {
+ if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
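+ // flash attention path: put V in (d_head, n_tokens, n_head) layout, matching K above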
+ ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+
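+ // cast K/V to F16 for broad flash-attention kernel support; accumulation precision is forced back to F32 below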
+ k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+ v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+
+ cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+ ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
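+ // the FA output is [d_head, n_head, n_tokens, ...]; fold the heads back into a single embedding dimension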
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+ } else {
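+ // vanilla attention path: V must be transposed (and made contiguous) for the V·softmax(KQ) matmul below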
+ ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+ v = ggml_cont(ctx0, v);
+
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
- // const auto n_kv = k->ne[1]; // for flash attention
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not be needed for vision encoders?
}
}
- void alloc_compute_meta(clip_ctx & ctx_clip) {
+ struct support_info_op {
+ ggml_tensor * op;
+
+ // true if the op runs on the accelerated ctx_clip.backend
+ bool is_accel = true;
+ };
+
+ struct support_info_graph {
+ // whether the clip_ctx.backend supports flash attention
+ bool fattn = true;
+ ggml_tensor * fattn_op = nullptr; // for debugging
+
+ std::vector<support_info_op> ops;
+ };
+
+ static void warmup(clip_ctx & ctx_clip) {
+ support_info_graph info;
+
+ if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
+ // try to enable flash attention to see if it's supported
+ ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+ info = alloc_compute_meta(ctx_clip);
+ if (!info.fattn && info.fattn_op) {
+ auto op = info.fattn_op;
+ LOG_WRN("%s: *****************************************************************\n", __func__);
+ LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
+ LOG_WRN("%s: op params: \n", __func__);
+ static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
+ LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
+ name, ggml_type_name(t->type),
+ (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3],
+ (int) t->nb[0], (int) t->nb[1], (int) t->nb[2], (int) t->nb[3]);
+ };
+ print_shape(__func__, " dst", op);
+ print_shape(__func__, "src0", op->src[0]);
+ print_shape(__func__, "src1", op->src[1]);
+ print_shape(__func__, "src2", op->src[2]);
+ LOG_WRN("%s: please report this on github as an issue\n", __func__);
+ LOG_WRN("%s: *****************************************************************\n", __func__);
+ ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
+ alloc_compute_meta(ctx_clip);
+ }
+ } else {
+ info = alloc_compute_meta(ctx_clip);
+ if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+ LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+ }
+ }
+
+ LOG_INF("%s: flash attention is %s\n", __func__,
+ (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+ // print ops that are not supported by the GPU backend (if there is one)
+ if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+ std::vector<support_info_op> unsupported_ops;
+ for (const auto & op : info.ops) {
+ if (!op.is_accel) {
+ unsupported_ops.push_back(op);
+ }
+ }
+ if (!unsupported_ops.empty()) {
+ LOG_WRN("%s: *****************************************************************\n", __func__);
+ LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+ LOG_WRN("%s: the performance will be suboptimal \n", __func__);
+ LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
+ for (const auto & op : unsupported_ops) {
+ LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
+ ggml_op_name(op.op->op),
+ ggml_type_name(op.op->type),
+ (int) op.op->ne[0], (int) op.op->ne[1], (int) op.op->ne[2], (int) op.op->ne[3]);
+ }
+ LOG_WRN("%s: flash attention is %s\n", __func__,
+ (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+ LOG_WRN("%s: please report this on github as an issue\n", __func__);
+ LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+ LOG_WRN("%s: *****************************************************************\n", __func__);
+ }
+ }
+ }
+
+ static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
size / 1024.0 / 1024.0);
}
}
+
+ const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
+ const int n_nodes = ggml_graph_n_nodes(gf);
+
+ LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+ support_info_graph res {
+ /*.fattn = */ true,
+ /*.fattn_op = */ nullptr,
+ /*.ops = */ {},
+ };
+
+ // check op support
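+ // nodes rejected by the accelerated backend are scheduled onto the CPU backend by ggml_backend_sched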
+ for (int i = 0; i < n_nodes; i++) {
+ ggml_tensor * node = ggml_graph_node(gf, i);
+ res.ops.push_back({node, true});
+ if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+ res.ops.back().is_accel = false;
+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+ res.fattn = false;
+ res.fattn_op = node;
+ }
+ }
+ }
+
+ return res;
}
- void get_bool(const std::string & key, bool & output, bool required = true) {
+ void get_bool(const std::string & key, bool & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
output = gguf_get_val_bool(ctx_gguf.get(), i);
}
- void get_i32(const std::string & key, int & output, bool required = true) {
+ void get_i32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
output = gguf_get_val_i32(ctx_gguf.get(), i);
}
- void get_u32(const std::string & key, int & output, bool required = true) {
+ void get_u32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
output = gguf_get_val_u32(ctx_gguf.get(), i);
}
- void get_f32(const std::string & key, float & output, bool required = true) {
+ void get_f32(const std::string & key, float & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
output = gguf_get_val_f32(ctx_gguf.get(), i);
}
- void get_string(const std::string & key, std::string & output, bool required = true) {
+ void get_string(const std::string & key, std::string & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
}
- void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+ void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
- if (required) throw std::runtime_error("Key not found: " + key);
+ if (required) {
+ throw std::runtime_error("Key not found: " + key);
+ }
return;
}
int n = gguf_get_arr_n(ctx_gguf.get(), i);
}
}
- void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+ static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
auto & hparams = model.hparams;
for (int x = 1; x <= max_patches_per_side; x++) {
for (int y = 1; y <= max_patches_per_side; y++) {
ctx_vision = new clip_ctx(ctx_params);
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
loader.load_tensors(*ctx_vision);
- loader.alloc_compute_meta(*ctx_vision);
+ loader.warmup(*ctx_vision);
}
if (loader.has_audio) {
ctx_audio = new clip_ctx(ctx_params);
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
loader.load_tensors(*ctx_audio);
- loader.alloc_compute_meta(*ctx_audio);
+ loader.warmup(*ctx_audio);
}
} catch (const std::exception & e) {
LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
- if (ctx_vision) {
- delete ctx_vision;
- }
- if (ctx_audio) {
- delete ctx_audio;
- }
+
+ delete ctx_vision;
+ delete ctx_audio;
+
return {nullptr, nullptr};
}
}
delete load_image_size;
}
-void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
return batch->entries.size();
#pragma once
#include "ggml.h"
+
#include <stddef.h>
#include <stdint.h>
CLIP_MODALITY_AUDIO,
};
+enum clip_flash_attn_type {
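+ // AUTO probes the backend during warmup and falls back to DISABLED when flash attention is unsupported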
+ CLIP_FLASH_ATTN_TYPE_AUTO = -1,
+ CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+ CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
+};
+
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
+ enum clip_flash_attn_type flash_attn_type;
};
struct clip_init_result {
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+ mparams.flash_attn_type = params.flash_attn_type;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);
#include <cstdio>
#include <cstdlib>
#include <cstring>
-#include <limits>
#include <vector>
// represents raw image data, layout is RGBRGBRGB...
return "<__media__>";
}
+static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
+ switch (flash_attn_type) {
+ case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
+ case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
+ case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
+ }
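+ // unreachable with the cases above; fallback for out-of-range values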
+ return CLIP_FLASH_ATTN_TYPE_AUTO;
+}
+
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
return params;
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
+ ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
auto res = clip_init(mmproj_fname, ctx_clip_params);
ctx_v = res.ctx_v;
ctx_a = res.ctx_a;
}
void mtmd_free(mtmd_context * ctx) {
- if (ctx) {
- delete ctx;
- }
+ delete ctx;
}
struct mtmd_tokenizer {
enum ggml_log_level verbosity;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
+ enum llama_flash_attn_type flash_attn_type;
};
MTMD_API const char * mtmd_default_marker(void);
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+ mparams.flash_attn_type = params_base.flash_attn_type;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());