]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model-conversion : add extra debugging support for model conversion (#15877)
authorPiotr Wilkin (ilintar) <redacted>
Tue, 9 Sep 2025 04:05:55 +0000 (06:05 +0200)
committerGitHub <redacted>
Tue, 9 Sep 2025 04:05:55 +0000 (06:05 +0200)
* feat: Extra debugging support for model conversion - added BF16 support for llama-callback-eval and support for dumping intermediate steps in run-org-model.py

examples/eval-callback/eval-callback.cpp
examples/model-conversion/requirements.txt
examples/model-conversion/scripts/causal/run-org-model.py

index d4ef751fbb63a44b81614a2612d019710bd00733..cefa39a57c886e1ac4d15d9bbedb743c72d969f8 100644 (file)
@@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
     return str;
 }
 
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
 static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
     size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
     float v;
@@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t *
         v = (float) *(int16_t *) &data[i];
     } else if (type == GGML_TYPE_I8) {
         v = (float) *(int8_t *) &data[i];
+    } else if (type == GGML_TYPE_BF16) {
+        v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
     } else {
         GGML_ABORT("fatal error");
     }
index b8148b269a249ed2740b857f05b57d8f73d959c3..ac9f69e10bcc9762111b980fc4dd8d5ba12fbd13 100644 (file)
@@ -1,5 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch~=2.6.0
-torchvision~=0.21.0
-transformers~=4.55.0
-huggingface-hub~=0.34.0
+torch
+torchvision
+transformers
+huggingface-hub
+accelerate
index f6188ea6f338218b4be2a78d79a6d31c81bbae14..78a54abf13cdac6c628637602de6c89e4d2fc047 100755 (executable)
@@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 import torch
 import numpy as np
 
-unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
-
-parser = argparse.ArgumentParser(description='Process model with specified path')
-parser.add_argument('--model-path', '-m', help='Path to the model')
+### If you want to dump RoPE activations, apply this monkey patch to the model
+### class from Transformers that you are running (replace apertus.modeling_apertus
+### with the proper package and class for your model
+### === START ROPE DEBUG ===
+# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
+
+# orig_rope = apply_rotary_pos_emb
+# torch.set_printoptions(threshold=float('inf'))
+# torch.set_printoptions(precision=6, sci_mode=False)
+
+# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+#     # log inputs
+#     summarize(q, "RoPE.q_in")
+#     summarize(k, "RoPE.k_in")
+
+#     # call original
+#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+#     # log outputs
+#     summarize(q_out, "RoPE.q_out")
+#     summarize(k_out, "RoPE.k_out")
+
+#     return q_out, k_out
+
+# # Patch it
+# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
+# apertus_mod.apply_rotary_pos_emb = debug_rope
+### == END ROPE DEBUG ===
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
+    print("                                     [")
+    print("                                      [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print("                                       ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else:  # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        print(f"                                       [{first_str}, ..., {last_str}]")
+
+    print("                                      ],")
+    print("                                     ]")
+    print(f"                                     sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+
+parser = argparse.ArgumentParser(description="Process model with specified path")
+parser.add_argument("--model-path", "-m", help="Path to the model")
 args = parser.parse_args()
 
-model_path = os.environ.get('MODEL_PATH', args.model_path)
+model_path = os.environ.get("MODEL_PATH", args.model_path)
 if model_path is None:
-    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
+    parser.error(
+        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
+    )
 
 config = AutoConfig.from_pretrained(model_path)
 
@@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path)
 
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
-    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+    unreleased_module_path = (
+        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+    )
     class_name = f"{unreleased_model_name}ForCausalLM"
     print(f"Importing unreleased model module: {unreleased_module_path}")
 
     try:
-        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
-        model = model_class.from_pretrained(model_path)  # Note: from_pretrained, not fromPretrained
+        model_class = getattr(
+            importlib.import_module(unreleased_module_path), class_name
+        )
+        model = model_class.from_pretrained(
+            model_path
+        )  # Note: from_pretrained, not fromPretrained
     except (ImportError, AttributeError) as e:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, device_map="auto", offload_folder="offload"
+    )
+
+for name, module in model.named_modules():
+    if len(list(module.children())) == 0:  # only leaf modules
+        module.register_forward_hook(debug_hook(name))
 
 model_name = os.path.basename(model_path)
 # Printing the Model class to allow for easier debugging. This can be useful