From: Georgi Gerganov Date: Thu, 14 Mar 2024 11:32:14 +0000 (+0200) Subject: gguf-py : fix dtype check (#6045) X-Git-Tag: upstream/0.0.4488~2061 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=77178eedc83d49f31bf757d8e12315d76460be78;p=pkg%2Fggml%2Fsources%2Fllama.cpp gguf-py : fix dtype check (#6045) --- diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9c1eeac3..4d389be9 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -204,7 +204,7 @@ class GGUFWriter: for i in range(n_dims): self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) if raw_dtype is None: - if tensor_shape == np.float32: + if tensor_dtype == np.float32: dtype = GGMLQuantizationType.F32 elif tensor_dtype == np.float16: dtype = GGMLQuantizationType.F16