]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Fix validation failure in quantized flash attention (llama/16292)
authorJeff Bolz <redacted>
Mon, 29 Sep 2025 04:50:37 +0000 (23:50 -0500)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:12 +0000 (15:18 +0300)
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

index e80eff27815ae36323e769f215872076eb8b76b7..9b1f153bf7f19190e42578557230fa366ae89c6b 100644 (file)
@@ -67,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 #if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
+layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 
 #if defined(DATA_A_Q4_0)
 #define BLOCK_BYTE_SIZE 18
 
 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    if (binding_idx == BINDING_IDX_K) {
+        uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+        uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+        uint shift = (iqs & 0x10) >> 2;
+        vui_lo >>= shift;
+        vui_hi >>= shift;
+
+        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    } else {
+        uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+        uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+        uint shift = (iqs & 0x10) >> 2;
+        vui_lo >>= shift;
+        vui_hi >>= shift;
+
+        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    }
 }
 #endif
 
 #if defined(DATA_A_Q8_0)
 #define BLOCK_BYTE_SIZE 34
 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+    if (binding_idx == BINDING_IDX_K) {
+        const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+        const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+    } else {
+        const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+        const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+    }
 }
 #endif