From: Piotr Wilkin (ilintar) Date: Tue, 9 Sep 2025 04:05:55 +0000 (+0200) Subject: model-conversion : add extra debugging support for model conversion (#15877) X-Git-Tag: upstream/0.0.6527~103 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=acc1b008cfd95e63d5f99a370dbffb98e5a99d2c;p=pkg%2Fggml%2Fsources%2Fllama.cpp model-conversion : add extra debugging support for model conversion (#15877) * feat: Extra debugging support for model conversion - added BF16 support for llama-callback-eval and support for dumping intermediate steps in run-org-model.py --- diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index d4ef751f..cefa39a5 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) { return str; } +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; @@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * v = (float) *(int16_t *) &data[i]; } else if (type == GGML_TYPE_I8) { v = (float) *(int8_t *) &data[i]; + } else if (type == GGML_TYPE_BF16) { + v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); } else { GGML_ABORT("fatal error"); } diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt index b8148b26..ac9f69e1 100644 --- a/examples/model-conversion/requirements.txt +++ b/examples/model-conversion/requirements.txt @@ -1,5 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.6.0 -torchvision~=0.21.0 -transformers~=4.55.0 -huggingface-hub~=0.34.0 +torch +torchvision +transformers +huggingface-hub +accelerate diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index f6188ea6..78a54abf 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import torch import numpy as np -unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') - -parser = argparse.ArgumentParser(description='Process model with specified path') -parser.add_argument('--model-path', '-m', help='Path to the model') +### If you want to dump RoPE activations, apply this monkey patch to the model +### class from Transformers that you are running (replace apertus.modeling_apertus +### with the proper package and class for your model +### === START ROPE DEBUG === +# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb + +# orig_rope = apply_rotary_pos_emb +# torch.set_printoptions(threshold=float('inf')) +# torch.set_printoptions(precision=6, sci_mode=False) + +# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +# # log inputs +# summarize(q, "RoPE.q_in") +# summarize(k, "RoPE.k_in") + +# # call original +# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim) + +# # log outputs +# summarize(q_out, "RoPE.q_out") +# summarize(k_out, "RoPE.k_out") + +# return q_out, k_out + +# # Patch it +# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402 +# apertus_mod.apply_rotary_pos_emb = debug_rope +### == END ROPE DEBUG === + + +def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3): + """ + Print a tensor in llama.cpp debug style. + + Supports: + - 2D tensors (seq, hidden) + - 3D tensors (batch, seq, hidden) + - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head + + Shows first and last max_vals of each vector per sequence position. + """ + t = tensor.detach().to(torch.float32).cpu() + + # Determine dimensions + if t.ndim == 3: + _, s, _ = t.shape + elif t.ndim == 2: + _, s = 1, t.shape[0] + t = t.unsqueeze(0) + elif t.ndim == 4: + _, s, _, _ = t.shape + else: + print(f"Skipping tensor due to unsupported dimensions: {t.ndim}") + return + + ten_shape = t.shape + + print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}") + print(" [") + print(" [") + + # Determine indices for first and last sequences + first_indices = list(range(min(s, max_seq))) + last_indices = list(range(max(0, s - max_seq), s)) + + # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq + has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s) + + # Combine indices + if has_overlap: + # If there's overlap, just use the combined unique indices + indices = sorted(list(set(first_indices + last_indices))) + separator_index = None + else: + # If no overlap, we'll add a separator between first and last sequences + indices = first_indices + last_indices + separator_index = len(first_indices) + + for i, si in enumerate(indices): + # Add separator if needed + if separator_index is not None and i == separator_index: + print(" ...") + + # Extract appropriate slice + vec = t[0, si] + if vec.ndim == 2: # 4D case: flatten heads × dim_per_head + flat = vec.flatten().tolist() + else: # 2D or 3D case + flat = vec.tolist() + + # First and last slices + first = flat[:max_vals] + last = flat[-max_vals:] if len(flat) >= max_vals else flat + first_str = ", ".join(f"{v:12.4f}" for v in first) + last_str = ", ".join(f"{v:12.4f}" for v in last) + + print(f" [{first_str}, ..., {last_str}]") + + print(" ],") + print(" ]") + print(f" sum = {t.sum().item():.6f}\n") + + +def debug_hook(name): + def fn(_m, input, output): + if isinstance(input, torch.Tensor): + summarize(input, name + "_in") + elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor): + summarize(input[0], name + "_in") + if isinstance(output, torch.Tensor): + summarize(output, name + "_out") + elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor): + summarize(output[0], name + "_out") + + return fn + + +unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") + +parser = argparse.ArgumentParser(description="Process model with specified path") +parser.add_argument("--model-path", "-m", help="Path to the model") args = parser.parse_args() -model_path = os.environ.get('MODEL_PATH', args.model_path) +model_path = os.environ.get("MODEL_PATH", args.model_path) if model_path is None: - parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + parser.error( + "Model path must be specified either via --model-path argument or MODEL_PATH environment variable" + ) config = AutoConfig.from_pretrained(model_path) @@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path) if unreleased_model_name: model_name_lower = unreleased_model_name.lower() - unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + unreleased_module_path = ( + f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + ) class_name = f"{unreleased_model_name}ForCausalLM" print(f"Importing unreleased model module: {unreleased_module_path}") try: - model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained + model_class = getattr( + importlib.import_module(unreleased_module_path), class_name + ) + model = model_class.from_pretrained( + model_path + ) # Note: from_pretrained, not fromPretrained except (ImportError, AttributeError) as e: print(f"Failed to import or load model: {e}") exit(1) else: - model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", offload_folder="offload" + ) + +for name, module in model.named_modules(): + if len(list(module.children())) == 0: # only leaf modules + module.register_forward_hook(debug_hook(name)) model_name = os.path.basename(model_path) # Printing the Model class to allow for easier debugging. This can be useful