print(tensor_name)
-def print_tensor_info(model_path: Path, tensor_name: str):
+def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
+ if num_values is not None:
+ tensor = f.get_tensor(tensor_name)
+ print(f"Dtype: {tensor.dtype}")
+ flat = tensor.flatten()
+ n = min(num_values, flat.numel())
+ print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
+ parser.add_argument(
+ "-n", "--num-values",
+ nargs="?",
+ const=10,
+ default=None,
+ type=int,
+ metavar="N",
+ help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
+ )
args = parser.parse_args()
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
- print_tensor_info(model_path, args.tensor_name)
+ print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":