]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: 64-bit im2col (llama/16135)
authorJeff Bolz <redacted>
Sun, 28 Sep 2025 06:38:37 +0000 (01:38 -0500)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:12 +0000 (15:18 +0300)
* vulkan: 64-bit im2col

Add variants of the im2col shaders that use buffer_device_address/buffer_reference,
and use 64-bit address calculations. This is needed for large convolutions used in
stable-diffusion.cpp.

* fix validation error for large im2col

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 9c17ad95efa73d2eec2ea91b389ab1c03b17d545..2608cbd06892c8369a796dc36a10c7a763813c2b 100644 (file)
@@ -408,6 +408,8 @@ struct vk_device_struct {
     bool subgroup_ballot;
     bool subgroup_clustered;
     bool multi_add;
+    bool shader_int64;
+    bool buffer_device_address;
 
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
     vk::MemoryPropertyFlags memory_property_flags;
     void * ptr;
     size_t size = 0;
+    vk::DeviceAddress bda_addr {};
 
     vk_device device;
 
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
 };
 
 struct vk_op_im2col_push_constants {
+    uint64_t dst_addr;
     uint32_t batch_offset; uint32_t offset_delta;
     uint32_t IC;
     uint32_t IW; uint32_t IH;
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
 };
 
 struct vk_op_im2col_3d_push_constants {
+    uint64_t dst_addr;
     uint32_t nb10;
     uint32_t nb11;
     uint32_t nb12;
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         return buf;
     }
 
+    vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
+    vk::MemoryAllocateFlags mem_flags {};
+    if (device->buffer_device_address) {
+        usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
+        mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
+    }
+
     vk::BufferCreateInfo buffer_create_info{
         vk::BufferCreateFlags(),
         size,
-        vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
+        usage_flags,
         vk::SharingMode::eExclusive,
         0,
         nullptr,
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
 
     vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
 
+    const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
+
     for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
         const auto & req_flags = *it;
 
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         buf->memory_property_flags = req_flags;
 
         try {
-            buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
+            buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
             break;
         } catch (const vk::SystemError& e) {
             // loop and retry
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
     buf->device = device;
     buf->size = size;
 
+    if (device->buffer_device_address) {
+        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
+        buf->bda_addr = device->device.getBufferAddress(addressInfo);
+    }
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     device->memory_logger->log_allocation(buf, size);
 #endif
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
-    if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
+#define IM2COL(bda) \
+    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
+    ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
+    if (device->float_controls_rte_fp16) {  \
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
+    } else {    \
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
+        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
+    }
+    if (device->shader_int64 && device->buffer_device_address) {
+        IM2COL(_bda)
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
+        IM2COL()
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
                             device->vendor_id != VK_VENDOR_ID_INTEL &&
                             getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
 
+        device->shader_int64 = device_features2.features.shaderInt64;
+        device->buffer_device_address = vk12_features.bufferDeviceAddress;
+
         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
             device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
+        if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
+            // buffer device address path doesn't use dst buffer
+            d_sz = 1;
+        }
         // im2col uses only src1 and dst buffers
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     const uint32_t pelements = OW * KW * KH;
 
+    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
+
+    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
+
     ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
+        dst_addr,
         batch_offset, offset_delta,
         IC, IW, IH, OW, OH, KW, KH,
         pelements,
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
     const int64_t OH = ne2;
     const int64_t OW = ne1;
 
+    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
+
+    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
+
     vk_op_im2col_3d_push_constants pc {};
 
+    pc.dst_addr = dst_addr;
     pc.nb10 = nb10 / ggml_type_size(src1->type);
     pc.nb11 = nb11 / ggml_type_size(src1->type);
     pc.nb12 = nb12 / ggml_type_size(src1->type);
index fdbcf7eba0fa588626499213b6947211aa9af6d2..f0f19a019ca26713708de3c26d77a8ffab778b0e 100644 (file)
@@ -5,8 +5,11 @@
 
 #include "rte.comp"
 
+#include "types.comp"
+
 layout (push_constant) uniform parameter
 {
+    BDA_STORAGE_T dst_addr;
     uint batch_offset; uint offset_delta;
     uint IC;
     uint IW; uint IH;
@@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
     int d0; int d1;
 } p;
 
-#include "types.comp"
-
 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 
 const uint NUM_ITER = 512 / BLOCK_SIZE;
@@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
+#if BDA
+layout (buffer_reference) buffer D_ptr {D_TYPE d;};
+#endif
+
 void main() {
     const uint gidx = gl_GlobalInvocationID.x;
 
@@ -38,7 +43,7 @@ void main() {
     const uint ic = gl_GlobalInvocationID.z % p.IC;
 
     const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
-    const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
+    const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
     const int oh_s1 = int(oh) * p.s1;
     const uint ksize = p.OW * p.KH;
 
@@ -50,7 +55,7 @@ void main() {
     uint current_ix = rem % p.OW;
 
     A_TYPE values[NUM_ITER];
-    uint offset_dst[NUM_ITER];
+    BDA_OFFSET_T offset_dst[NUM_ITER];
     [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
         values[idx] = A_TYPE(0);
     }
@@ -66,7 +71,7 @@ void main() {
         const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
         const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
 
-        offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
+        offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
 
         if ((iih < p.IH) && (iiw < p.IW)) {
             values[idx] = data_a[src_base + iih * p.IW + iiw];
@@ -89,7 +94,11 @@ void main() {
             continue;
         }
 
+#if BDA
+        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
+        dst_addr.d = D_TYPE(values[idx]);
+#else
         data_d[offset_dst[idx]] = D_TYPE(values[idx]);
+#endif
     }
-
 }
index 3b010bdeb5eb8fc00509c66952e2d237ad696c7f..9faa636ac251a576fe180deaa6bbd30e96de1e55 100644 (file)
@@ -6,8 +6,11 @@
 
 #include "rte.comp"
 
+#include "types.comp"
+
 layout (push_constant) uniform parameter
 {
+    BDA_STORAGE_T dst_addr;
     uint32_t nb10;
     uint32_t nb11;
     uint32_t nb12;
@@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
     uint32_t misalign_offsets;
 } p;
 
-#include "types.comp"
-
 uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
 
@@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
+#if BDA
+layout (buffer_reference) buffer D_ptr {D_TYPE d;};
+#endif
+
 void main() {
     const uint32_t i = gl_GlobalInvocationID.x;
 
@@ -100,13 +105,22 @@ void main() {
         const uint32_t iih = ioh * s1 + ikh * d1 - p1;
         const uint32_t iid = iod * s2 + ikd * d2 - p2;
 
-        const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
+        const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
 
+        const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
+#if BDA
+        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
+        if (iih >= IH || iiw >= IW || iid >= ID) {
+            dst_addr.d = D_TYPE(0.0f);
+        } else {
+            dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
+        }
+#else
         if (iih >= IH || iiw >= IW || iid >= ID) {
             data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
         } else {
-            const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
             data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
         }
+#endif
     }
 }
index 5032ed1738c0942645f69e594651882dbf148c8b..2fa54ce51fc83f156fe0895f0d1b7bc0717366b5 100644 (file)
@@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
     return uintBitsToFloat(bits);
 }
 
+#if BDA
+
+#extension GL_EXT_buffer_reference : enable
+#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable
+
+#define BDA_STORAGE_T uint64_t
+#define BDA_OFFSET_T uint64_t
+
+#else
+
+#define BDA_STORAGE_T uvec2
+#define BDA_OFFSET_T uint
+
+#endif
+
 #endif // !defined(GGML_TYPES_COMP)
index d2591e26b6c17863282bae0cf503f5336eb5386e..84bb9df9a05594d0253c57e4956cf9f6e690df4a 100644 (file)
@@ -775,13 +775,15 @@ void process_shaders() {
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
-    string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
-    string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
-
-    string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
-    string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
+    for (std::string dim_str : {"", "_3d"}) {
+        for (bool bda : {false, true}) {
+            std::string bda_str = bda ? "_bda" : "";
+            std::string bda_def = bda ? "1" : "0";
+            string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
+            string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
+            string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
+        }
+    }
 
     string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));