From: Georgi Gerganov
Date: Mon, 23 Sep 2024 08:27:47 +0000 (+0300)
Subject: metal : use F32 prec for K*Q in vec FA (llama/9595)
X-Git-Tag: upstream/0.0.1642~347
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=64f30f362597c427889446484b8f6a9176f14601;p=pkg%2Fggml%2Fsources%2Fggml

metal : use F32 prec for K*Q in vec FA (llama/9595)

ggml-ci
---

diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal
index f323ab5f..2b200032 100644
--- a/src/ggml-metal.metal
+++ b/src/ggml-metal.metal
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
     const short iv3 = iq3 / rv3;
 
     // load the queries from shared memory into local memory
-    half4 mq[D4];
+    float4 mq[D4];
 
     for (short ii = 0; ii < D4; ii += NW) {
         short i = ii + tiisg;
-        mq[i] = sq4[i];
+        mq[i] = (float4) sq4[i];
     }
 
     // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
                 for (short ii = 0; ii < D4; ii += NW) {
                     const short i = ii + tiisg;
 
-                    half4x4 mk;
-                    mk[0] = pk4[i + 0*(nb11/8)];
-                    mk[1] = pk4[i + 1*(nb11/8)];
-                    mk[2] = pk4[i + 2*(nb11/8)];
-                    mk[3] = pk4[i + 3*(nb11/8)];
+                    float4x4 mk;
+                    mk[0] = (float4) pk4[i + 0*(nb11/8)];
+                    mk[1] = (float4) pk4[i + 1*(nb11/8)];
+                    mk[2] = (float4) pk4[i + 2*(nb11/8)];
+                    mk[3] = (float4) pk4[i + 3*(nb11/8)];
 
                     mqk += (float4) (mq[i] * mk);
                 }