]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py : fix dtype check (#6045)
authorGeorgi Gerganov <redacted>
Thu, 14 Mar 2024 11:32:14 +0000 (13:32 +0200)
committerGeorgi Gerganov <redacted>
Thu, 14 Mar 2024 11:32:14 +0000 (13:32 +0200)
gguf-py/gguf/gguf_writer.py

index 9c1eeac318c7dc4ae4df6527825d7f279875e98f..4d389be951d721031b5357002d5dc4d72a6c761d 100644 (file)
@@ -204,7 +204,7 @@ class GGUFWriter:
         for i in range(n_dims):
             self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
-            if tensor_shape == np.float32:
+            if tensor_dtype == np.float32:
                 dtype = GGMLQuantizationType.F32
             elif tensor_dtype == np.float16:
                 dtype = GGMLQuantizationType.F16