]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : remove bug in convert.py permute function (#3364)
authorZhang Peiyuan <redacted>
Wed, 27 Sep 2023 18:45:20 +0000 (02:45 +0800)
committerGitHub <redacted>
Wed, 27 Sep 2023 18:45:20 +0000 (20:45 +0200)
convert.py

index 4ac5030db61eb2ce1cdc99242f4a9100f43bb768..8bb6c7e4108523b2983d693cdd0e7949f4fafcd9 100755 (executable)
@@ -439,7 +439,7 @@ Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
 def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
     #print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
     if n_head_kv is not None and n_head != n_head_kv:
-        n_head //= n_head_kv
+        n_head = n_head_kv
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                 .swapaxes(1, 2)
                 .reshape(weights.shape))