From: lhez Date: Fri, 21 Nov 2025 22:34:48 +0000 (-0800) Subject: opencl: refine condition for kqv mm (llama/17392) X-Git-Tag: upstream/0.9.4.395~136 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=774088bcc1c7207becf055038ac9602239ebeb7c;p=pkg%2Fggml%2Fsources%2Fggml opencl: refine condition for kqv mm (llama/17392) --- diff --git a/src/ggml-opencl/ggml-opencl.cpp b/src/ggml-opencl/ggml-opencl.cpp index 4cb6afe9..2319f7a9 100644 --- a/src/ggml-opencl/ggml-opencl.cpp +++ b/src/ggml-opencl/ggml-opencl.cpp @@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_context context = backend_ctx->context; if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){ - ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); - return; + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + // For KQ + if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + nb00 <= nb02 && + nb02 <= nb01 && + nb01 <= nb03 && + nb10 <= nb12 && + nb12 <= nb11 && + nb11 <= nb13) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } + // For KQV + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } } }