]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commit
metal: minor q4 optimization and reduce code size (#2248)
authorShouzheng Liu <redacted>
Thu, 20 Jul 2023 10:32:22 +0000 (06:32 -0400)
committerGitHub <redacted>
Thu, 20 Jul 2023 10:32:22 +0000 (13:32 +0300)
commit417a85a0010519224cf154eb85d383ffeafeeead
treeeb9b9668426c7318e2ab1389f04118e126752a8e
parent294f424554c1599784ac9962462fc39ace92d8a5
metal: minor q4 optimization and reduce code size (#2248)

* metal: use uint16_t instead of uint8_t.

Apple GPU doesn't like uint8_t. For every operation on uint8_t
the gpu need to copy the uint8_t to an empty 16 bit register, then
it can issue other instructions.

For the matrix-vector multiplication kernel only, we observed a
340~350 GB/s memory read speed on M1 Max after this commit, which is
very close to the reported hardware limit.

* metal: update rms_norm kernel

This commit double the speed of rms_norm operations by using 512 threads
per threadgroup, combining with SIMD primitives to minimize the need for
thread group barriers.

* metal: use template to reduce size

Revert modifications on block_q4_0 and block_q4_1.
ggml-metal.m
ggml-metal.metal