gemma : fix bfloat16 -> float16 conversion issue (#5810)
author kunal-vaishnavi <redacted>
Fri, 1 Mar 2024 14:08:08 +0000 (06:08 -0800)
committer GitHub <redacted>
Fri, 1 Mar 2024 14:08:08 +0000 (16:08 +0200)
convert-hf-to-gguf.py

index ae30b2a76971a91bd97f5dc90182ec73bed2f976..d3e8ec1f60c70ea9f8fa68a7597cd09bd33d7cc1 100755 (executable)
@@ -1811,16 +1811,15 @@ class GemmaModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
 
         for name, data_torch in self.get_tensors():
-            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
-            if name.endswith("norm.weight"):
-                data_torch = data_torch + 1
-
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
             data = data_torch.squeeze().numpy()
 
             # map tensor names
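
Context for the change (not part of the commit): moving the "+ 1" on norm.weight below the dtype check means the addition is performed in float32 rather than in the tensor's original bfloat16, whose 8-bit significand carries only about 2-3 significant decimal digits. A minimal sketch of the precision difference, using torch directly (the example value is illustrative, not taken from the converter):

    import torch

    w = torch.tensor([0.1234], dtype=torch.bfloat16)  # a norm weight stored as bfloat16

    # old order: add 1 while still in bfloat16, then upcast
    old = (w + 1).to(torch.float32)

    # new order (this commit): upcast to float32 first, then add 1
    new = w.to(torch.float32) + 1

    # roughly 1.125 vs. 1.1235...; the new order keeps the extra precision,
    # which then survives a later cast to float16
    print(old.item(), new.item())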