]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : use FA (#16837)
authorGeorgi Gerganov <redacted>
Sun, 2 Nov 2025 20:21:48 +0000 (22:21 +0200)
committerGitHub <redacted>
Sun, 2 Nov 2025 20:21:48 +0000 (21:21 +0100)
* clip : use FA

* cont : add warning about unsupported ops

* implement "auto" mode for clip flash attn

* clip : print more detailed op support info during warmup

* cont : remove obsolete comment [no ci]

* improve debugging message

* trailing space

* metal : remove stray return

---------

Co-authored-by: Xuan Son Nguyen <redacted>
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp
tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd-cli.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h
tools/server/server.cpp

index 360fbe19f0fb68131cc6725b039594c21b65a440..0cadd19a30fe96f1b86730d2459bebf78188a0c3 100644 (file)
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             if (op->src[0]->ne[0] != 32 &&
                 op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 72 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
                 op->src[0]->ne[0] != 112 &&
index fa839a1df6e304c8f85adb7e392e9fbd7c1d8d37..424c400f24b9bd084cb59cb59cd9991afe43dd9f 100644 (file)
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
 template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  32,  32>;
 template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  40,  40>;
 template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  64,  64>;
+template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  72,  72>;
 template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  80,  80>;
 template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  96,  96>;
 template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  112, 112>;
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]]  kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  32,  32>;
 template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  40,  40>;
 template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;
+template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  72,  72>;
 template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;
 template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;
 template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]]  kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32,  32>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72,  72>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
index 04fa1b62d3b4d19a5d0779239ffa37f74d4cf509..2886bd37d680cf2d37e3d69d79efb0ef04408474 100644 (file)
@@ -7225,8 +7225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }
 
-    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
index dcfdb49600b6c3399fd796347fd0477424211234..a7e1799e93d45701d2d78e5de02818fea169a13c 100644 (file)
@@ -6,7 +6,6 @@
 #include "clip-impl.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
-#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"
 #include <cstring>
 #include <fstream>
 #include <map>
-#include <regex>
 #include <stdexcept>
 #include <unordered_set>
 #include <vector>
-#include <sstream>
 #include <cinttypes>
 #include <limits>
 #include <array>
-#include <numeric>
 #include <functional>
 
+// TODO: allow to pass callback from user code
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
 
 enum ffn_op_type {
@@ -426,12 +423,14 @@ struct clip_ctx {
 
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
 
     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;
 
     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2260,17 +2259,25 @@ private:
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);
 
-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
-        v = ggml_cont(ctx0, v);
-        //cb(k, "v", il);
-
         ggml_tensor * cur;
 
-        // TODO @ngxson : support flash attention
-        {
+        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
             const auto n_tokens = q->ne[1];
             const auto n_head   = q->ne[2];
-            // const auto n_kv     = k->ne[1]; // for flash attention
 
             ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             // F32 may not needed for vision encoders?
@@ -3192,7 +3199,87 @@ struct clip_model_loader {
         }
     }
 
-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    struct support_info_op {
+        ggml_tensor * op;
+
+        // true if the op runs on the accelerated ctx_clip.backend
+        bool is_accel = true;
+    };
+
+    struct support_info_graph {
+        // whether the clip_ctx.backend supports flash attention
+        bool fattn = true;
+        ggml_tensor * fattn_op = nullptr; // for debugging
+
+        std::vector<support_info_op> ops;
+    };
+
+    static void warmup(clip_ctx & ctx_clip) {
+        support_info_graph info;
+
+        if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && info.fattn_op) {
+                auto op = info.fattn_op;
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
+                LOG_WRN("%s: op params: \n", __func__);
+                static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
+                    LOG_WRN("%s:   %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
+                            name, ggml_type_name(t->type),
+                            t->ne[0], t->ne[1], t->ne[2], t->ne[3],
+                            t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
+                };
+                print_shape(__func__, " dst", op);
+                print_shape(__func__, "src0", op->src[0]);
+                print_shape(__func__, "src1", op->src[1]);
+                print_shape(__func__, "src2", op->src[2]);
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+            (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+        // print ops that are not supported by the GPU backend (if there is one)
+        if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            std::vector<support_info_op> unsupported_ops;
+            for (const auto & op : info.ops) {
+                if (!op.is_accel) {
+                    unsupported_ops.push_back(op);
+                }
+            }
+            if (!unsupported_ops.empty()) {
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+                LOG_WRN("%s:          the performance will be suboptimal                      \n", __func__);
+                LOG_WRN("%s:          list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
+                for (const auto & op : unsupported_ops) {
+                    LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
+                            ggml_op_name(op.op->op),
+                            ggml_type_name(op.op->type),
+                            op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
+                }
+                LOG_WRN("%s: flash attention is %s\n", __func__,
+                    (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+            }
+        }
+    }
+
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
@@ -3223,57 +3310,95 @@ struct clip_model_loader {
                         size / 1024.0 / 1024.0);
             }
         }
+
+        const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
+        const int n_nodes  = ggml_graph_n_nodes(gf);
+
+        LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__,  n_splits, n_nodes);
+
+        support_info_graph res {
+            /*.fattn    = */ true,
+            /*.fattn_op = */ nullptr,
+            /*.ops      = */ {},
+        };
+
+        // check op support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            res.ops.push_back({node, true});
+            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                res.ops.back().is_accel = false;
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    res.fattn    = false;
+                    res.fattn_op = node;
+                }
+            }
+        }
+
+        return res;
     }
 
-    void get_bool(const std::string & key, bool & output, bool required = true) {
+    void get_bool(const std::string & key, bool & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_bool(ctx_gguf.get(), i);
     }
 
-    void get_i32(const std::string & key, int & output, bool required = true) {
+    void get_i32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_i32(ctx_gguf.get(), i);
     }
 
-    void get_u32(const std::string & key, int & output, bool required = true) {
+    void get_u32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_u32(ctx_gguf.get(), i);
     }
 
-    void get_f32(const std::string & key, float & output, bool required = true) {
+    void get_f32(const std::string & key, float & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_f32(ctx_gguf.get(), i);
     }
 
-    void get_string(const std::string & key, std::string & output, bool required = true) {
+    void get_string(const std::string & key, std::string & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
     }
 
-    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3284,7 +3409,7 @@ struct clip_model_loader {
         }
     }
 
-    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+    static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
             for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3312,24 +3437,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
             ctx_vision = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
             loader.load_tensors(*ctx_vision);
-            loader.alloc_compute_meta(*ctx_vision);
+            loader.warmup(*ctx_vision);
         }
 
         if (loader.has_audio) {
             ctx_audio = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
             loader.load_tensors(*ctx_audio);
-            loader.alloc_compute_meta(*ctx_audio);
+            loader.warmup(*ctx_audio);
         }
 
     } catch (const std::exception & e) {
         LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-        if (ctx_vision) {
-            delete ctx_vision;
-        }
-        if (ctx_audio) {
-            delete ctx_audio;
-        }
+
+        delete ctx_vision;
+        delete ctx_audio;
+
         return {nullptr, nullptr};
     }
 
@@ -3367,10 +3490,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8  * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8  * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
 
 size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
     return batch->entries.size();
index 3387cdbd3695510066302d78a2e386a271f47803..6384e2adaf77535a60275886fb360882d6d243e9 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 
@@ -22,9 +23,16 @@ enum clip_modality {
     CLIP_MODALITY_AUDIO,
 };
 
+enum clip_flash_attn_type {
+    CLIP_FLASH_ATTN_TYPE_AUTO     = -1,
+    CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+    CLIP_FLASH_ATTN_TYPE_ENABLED  = 1,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    enum clip_flash_attn_type flash_attn_type;
 };
 
 struct clip_init_result {
index fd1fb6581b1634dd560d56a9224303a893448ca7..17aea1472b3c6b2826ec6a137f7a1ca905b899ab 100644 (file)
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
         mparams.print_timings = true;
         mparams.n_threads = params.cpuparams.n_threads;
         mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type = params.flash_attn_type;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
index 196641dd95ef4cd8841130c03d7eeaebab115b35..297eef437ab912e1bf0cdb22f537a337949364bc 100644 (file)
@@ -19,7 +19,6 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
-#include <limits>
 #include <vector>
 
 // represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
     return "<__media__>";
 }
 
+static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO:     return CLIP_FLASH_ATTN_TYPE_AUTO;
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED:  return CLIP_FLASH_ATTN_TYPE_ENABLED;
+    }
+    return CLIP_FLASH_ATTN_TYPE_AUTO;
+}
+
 mtmd_context_params mtmd_context_params_default() {
     mtmd_context_params params;
     params.use_gpu = true;
@@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     return params;
 }
 
@@ -164,6 +173,7 @@ struct mtmd_context {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu   = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
+        ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;
@@ -378,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
 }
 
 void mtmd_free(mtmd_context * ctx) {
-    if (ctx) {
-        delete ctx;
-    }
+    delete ctx;
 }
 
 struct mtmd_tokenizer {
index 0b5d2ba0c763418fb0d51b6291b51cc46c8b0ca7..4ae1925bcdfb64775ababd8f32d0a85a219b5724 100644 (file)
@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    enum llama_flash_attn_type flash_attn_type;
 };
 
 MTMD_API const char * mtmd_default_marker(void);
index aa4981585200adb7bad00ae1db54ddbd72a735fc..a9bef35189b3ab877b82153366fccf51e5e2885c 100644 (file)
@@ -2456,6 +2456,7 @@ struct server_context {
             mparams.print_timings = false;
             mparams.n_threads     = params_base.cpuparams.n_threads;
             mparams.verbosity     = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+            mparams.flash_attn_type = params_base.flash_attn_type;
             mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
             if (mctx == nullptr) {
                 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());