]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: reject ops when a tensor is too large to allocate (#18646)
authorJeff Bolz <redacted>
Wed, 7 Jan 2026 11:03:32 +0000 (05:03 -0600)
committerGitHub <redacted>
Wed, 7 Jan 2026 11:03:32 +0000 (12:03 +0100)
ggml/src/ggml-vulkan/ggml-vulkan.cpp

index 1f255b705e09c5f9ab778004ae25db941a2eb85d..d68735a040adbc1d1dc83402bb304fc5b4af1951 100644 (file)
@@ -14305,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    const vk_device& device = ggml_vk_get_device(ctx->device);
+
+    // reject any tensors larger than the max buffer size
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
+            return false;
+        }
+    }
+    if (ggml_nbytes(op) > device->max_buffer_size) {
+        return false;
+    }
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -14353,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL_MAT_ID:
             {
                 ggml_type src0_type = op->src[0]->type;
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
                 if (op->op == GGML_OP_MUL_MAT_ID) {
                     if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                         // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
@@ -14415,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 bool coopmat2 = device->coopmat2;
                 uint32_t HSK = op->src[1]->ne[0];
                 uint32_t HSV = op->src[2]->ne[0];
@@ -14638,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // pipeline_argsort_large_f32 requires vulkan memory model.
                 if (device->vulkan_memory_model) {
                     return true;
@@ -14652,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // We could potentially support larger, using argsort to sort the
                 // whole thing. Not clear if this is needed.
                 uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
@@ -14700,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_CUMSUM:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
                     return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
                 }
@@ -14709,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_SOLVE_TRI:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                 if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
@@ -14776,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
                 }
 
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                 const uint32_t SPLIT_H = 16;
 
                 size_t stateC_size = SPLIT_H * d_state * sizeof(float);