# The scale is inverted
return data / scale.float()
- def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+ def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
scale = scale.float()
- if (weight_block_size := quant_config.get("weight_block_size")):
- # TODO: make sure it's a list of integers
- for i, size in enumerate(weight_block_size):
+ if block_size is not None:
+ for i, size in enumerate(block_size):
scale = scale.repeat_interleave(size, i)
- # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
- scale = scale[tuple(slice(0, size) for size in weight.shape)]
+ # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+ scale = scale[tuple(slice(0, size) for size in weight.shape)]
return weight.float() * scale
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+ def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+ assert w.dtype == torch.int32
+ shape = tuple(shape_tensor.tolist())
+ assert len(shape) == 2
+ mask = (1 << num_bits) - 1
+
+ shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+ if self.lazy:
+ shifts = LazyTorchTensor.from_eager(shifts)
+
+ if zero_point is None:
+ offset = 1 << (num_bits - 1)
+ else:
+ assert len(zero_point.shape) == 2
+ offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+ offset = offset.reshape(-1, zero_point.shape[1])
+ # trim padding, and prepare for broadcast
+ # NOTE: the zero-point is packed along dim 0
+ offset = offset[:shape[0], :].unsqueeze(-1)
+
+ # extract values
+ # NOTE: the weights are packed along dim 1
+ unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+ unpacked = unpacked.reshape(shape[0], -1)
+
+ # trim padding
+ unpacked = unpacked[:, :shape[1]]
+
+ # prepare for broadcast of the scale
+ unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+ unpacked = unpacked - offset
+
+ return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
tensors_to_remove.append(name)
elif quant_method == "fp8":
+ block_size = quant_config.get("weight_block_size")
for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"):
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
- self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+ self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
tensors_to_remove.append(name)
elif quant_method == "gptq":
for name in self.model_tensors.keys():
".scales",
)
]
+ elif quant_method == "compressed-tensors":
+ quant_format = quant_config["format"]
+ groups = quant_config["config_groups"]
+ if len(groups) > 1:
+ raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
+ weight_config = tuple(groups.values())[0]["weights"]
+
+ if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+ block_size = weight_config.get("block_structure", None)
+ strategy = weight_config.get("strategy")
+ assert strategy == "channel" or strategy == "block"
+ assert weight_config.get("group_size") is None # didn't find a model using this yet
+ for name in self.model_tensors.keys():
+ if name.endswith(".weight_scale"):
+ weight_name = name.removesuffix("_scale")
+ w = self.model_tensors[weight_name]
+ s = self.model_tensors[name]
+ self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+ tensors_to_remove.append(name)
+ elif quant_format == "pack-quantized":
+ assert weight_config.get("strategy") == "group"
+ assert weight_config.get("type", "int") == "int"
+ num_bits = weight_config.get("num_bits")
+ group_size = weight_config.get("group_size")
+ assert isinstance(num_bits, int)
+ assert isinstance(group_size, int)
+ for name in self.model_tensors.keys():
+ if name.endswith(".weight_packed"):
+ base_name = name.removesuffix("_packed")
+ w = self.model_tensors[name]
+ scale = self.model_tensors[base_name + "_scale"]
+ shape = self.model_tensors[base_name + "_shape"]
+ zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+ new_tensors[base_name] = (
+ lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+ w(), scale(), shape(), zero_point(), num_bits, group_size,
+ )
+ )
+ tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+ if (base_name + "_zero_point") in self.model_tensors:
+ tensors_to_remove.append(base_name + "_zero_point")
+ else:
+ raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
else:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
# NOTE: doing this from a metaclass is very convenient
# TODO: make this even more comprehensive
for binary_op in (
- "lt", "le", "eq", "ne", "ge", "gt", "not"
- "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
- "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
+ "lt", "le", "eq", "ne", "ge", "gt",
+ "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
+ "or", "pow", "rshift", "sub", "truediv", "xor",
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
):
attr_name = f"__{binary_op}__"
+ # evaluation on the meta tensor is needed in case there's broadcasting
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+ for unary_op in ("not", "abs", "invert", "neg", "pos"):
+ attr_name = f"__{unary_op}__"
# the result of these operators usually has the same shape and dtype as the input,
# so evaluation on the meta tensor can be skipped.
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)