From: Georgi Gerganov
Date: Mon, 23 Sep 2024 08:27:47 +0000 (+0300)
Subject: metal : use F32 prec for K*Q in vec FA (#9595)
X-Git-Tag: upstream/0.0.4488~681
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=bf9c1013ac40e5f1bd8e60b6d8bf16e0e8401445;p=pkg%2Fggml%2Fsources%2Fllama.cpp

metal : use F32 prec for K*Q in vec FA (#9595)

ggml-ci
---

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f323ab5f..2b200032 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;

         // load the queries from shared memory into local memory
-        half4 mq[D4];
+        float4 mq[D4];

         for (short ii = 0; ii < D4; ii += NW) {
             short i = ii + tiisg;
-            mq[i] = sq4[i];
+            mq[i] = (float4) sq4[i];
         }

         // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     for (short ii = 0; ii < D4; ii += NW) {
                         const short i = ii + tiisg;

-                        half4x4 mk;
-                        mk[0] = pk4[i + 0*(nb11/8)];
-                        mk[1] = pk4[i + 1*(nb11/8)];
-                        mk[2] = pk4[i + 2*(nb11/8)];
-                        mk[3] = pk4[i + 3*(nb11/8)];
+                        float4x4 mk;
+                        mk[0] = (float4) pk4[i + 0*(nb11/8)];
+                        mk[1] = (float4) pk4[i + 1*(nb11/8)];
+                        mk[2] = (float4) pk4[i + 2*(nb11/8)];
+                        mk[3] = (float4) pk4[i + 3*(nb11/8)];

                         mqk += (float4) (mq[i] * mk);
                     }