From: kunal-vaishnavi
Date: Fri, 1 Mar 2024 14:08:08 +0000 (-0800)
Subject: gemma : fix bfloat16 -> float16 conversion issue (#5810)
X-Git-Tag: upstream/0.0.4488~2182
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e7433867288d2f142cffe596f3751bda5d7ee2c7;p=pkg%2Fggml%2Fsources%2Fllama.cpp

gemma : fix bfloat16 -> float16 conversion issue (#5810)
---

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index ae30b2a7..d3e8ec1f 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1811,16 +1811,15 @@ class GemmaModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
 
         for name, data_torch in self.get_tensors():
-            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
-            if name.endswith("norm.weight"):
-                data_torch = data_torch + 1
-
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
             data = data_torch.squeeze().numpy()
 
             # map tensor names
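
Note (not part of the patch): a minimal sketch of why the "+ 1" is moved after the float32 upcast, assuming the motivation is precision. Adding 1 while the norm weight is still bfloat16 rounds the result to bfloat16's coarse spacing around 1.0 (steps of 1/128), whereas upcasting first preserves the small fractional offsets. The weight values below are made up for illustration.

import torch

# Hypothetical Gemma-style norm weights: small offsets around zero,
# to which the converter adds 1 to match the HF reference implementation.
w = torch.tensor([0.0051, -0.0032, 0.0007], dtype=torch.bfloat16)

# Old order: add 1 while still bfloat16, then upcast.
# Near 1.0, bfloat16 can only represent steps of 1/128, so the small
# fractional part is rounded away before the upcast.
old_order = (w + 1).to(torch.float32)

# New order (this patch): upcast to float32 first, then add 1,
# keeping the full fractional precision of the weight.
new_order = w.to(torch.float32) + 1

print(old_order)
print(new_order)
print((old_order - new_order).abs().max())  # non-zero rounding error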