]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : handle pre-quantized models (#14810)
authorcompilade <redacted>
Thu, 23 Oct 2025 20:31:41 +0000 (16:31 -0400)
committerGitHub <redacted>
Thu, 23 Oct 2025 20:31:41 +0000 (16:31 -0400)
* convert : begin handling pre-quantized models

* convert : fix conversion from FP8 for Deepseek-V3.1-Base

convert_hf_to_gguf.py

index 7b49969c02149b9e59c0e0692f392fe0eb5ef3bf..3e3db999c92edb3051a69633a5572f8e0a3573b2 100755 (executable)
@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -137,25 +135,8 @@ class ModelBase:
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ class ModelBase:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -192,67 +175,215 @@ class ModelBase:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
 
         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
 
             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
 
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -4381,27 +4512,6 @@ class CodeShellModel(TextModel):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):