From: Jeff Bolz Date: Wed, 27 Nov 2024 07:08:54 +0000 (-0600) Subject: vulkan: skip integer div/mod in get_offsets for batch_idx==0 (llama/10506) X-Git-Tag: upstream/0.0.1642~130 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=bc7255385b526d42e1c31893d0ec2b10c805b9a4;p=pkg%2Fggml%2Fsources%2Fggml vulkan: skip integer div/mod in get_offsets for batch_idx==0 (llama/10506) --- diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 8d0a5791..2ec1af5c 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -52,13 +52,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif #ifndef MUL_MAT_ID - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; - const uint batch_idx_a = i03 * p.ne02 + i02; + batch_idx_a = i03 * p.ne02 + i02; + } #else const uint expert_id = data_ids[expert_idx]; #endif