# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import re import shutil import sys from collections import OrderedDict, defaultdict from dataclasses import dataclass from typing import Any _DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)") # numbers between dots or at the end def _pattern_of(key: str) -> str: """Replace every dot-delimited integer with '*' to get the structure.""" return _DIGIT_RX.sub("*", key) def _fmt_indices(values: list[int], cutoff=10) -> str: """Format a list of ints as single number, {a, ..., b}, or first...last.""" if len(values) == 1: return str(values[0]) values = sorted(values) if len(values) > cutoff: return f"{values[0]}...{values[-1]}" return ", ".join(map(str, values)) def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]: """ Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x' BUT only merge together keys that have the exact same value. Returns a new dict {merged_key: value}. """ # (pattern, value) -> list[set[int]] (per-star index values) not_mapping = False if not isinstance(mapping, dict): mapping = {k: k for k in mapping} not_mapping = True bucket: dict[tuple[str, Any], list[set[int]]] = defaultdict(list) for key, val in mapping.items(): digs = _DIGIT_RX.findall(key) patt = _pattern_of(key) for i, d in enumerate(digs): if len(bucket[patt]) <= i: bucket[patt].append(set()) bucket[patt][i].add(int(d)) bucket[patt].append(val) out_items = {} for patt, values in bucket.items(): sets, val = values[:-1], values[-1] parts = patt.split("*") # stars are between parts final = parts[0] for i in range(1, len(parts)): if i - 1 < len(sets) and sets[i - 1]: insert = _fmt_indices(sorted(sets[i - 1])) if len(sets[i - 1]) > 1: final += "{" + insert + "}" else: final += insert else: final += "*" final += parts[i] out_items[final] = val out = OrderedDict(out_items) if not_mapping: return out.keys() return out _ansi_re = re.compile(r"\x1b\[[0-9;]*m") def _strip_ansi(s: str) -> str: return _ansi_re.sub("", str(s)) def _pad(text, width): t = str(text) pad = max(0, width - len(_strip_ansi(t))) return t + " " * pad def _make_table(rows, headers): # compute display widths while ignoring ANSI codes cols = list(zip(*([headers] + rows))) if rows else [headers] widths = [max(len(_strip_ansi(x)) for x in col) for col in cols] header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths)) sep_line = "-+-".join("-" * w for w in widths) body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows] return "\n".join([header_line, sep_line] + body) PALETTE = { "reset": "", "red": "", "yellow": "", "orange": "", "purple": "", "bold": "", "italic": "", "dim": "", } def _color(s, color): """Return color-formatted input `s` if `sys.stdout` is interactive, e.g. connected to a terminal.""" if sys.stdout.isatty(): return f"{PALETTE[color]}{s}{PALETTE['reset']}" else: return s def _get_terminal_width(default=80): try: return shutil.get_terminal_size().columns except Exception: return default @dataclass class LoadStateDictInfo: """ Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable, and will usually be mutated in-place during the loading pipeline. Attributes: missing_keys (`set[str]`): Keys that are missing from the loaded checkpoints but expected in the model's architecture. unexpected_keys (`set[str]`): Keys that are found in the checkpoints, but not expected in the model's architecture. mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`): Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape. error_msgs ( `list[str]`): Some potential error messages. conversion_errors (`dict[str, str]`): Errors happening during the on-the-fly weight conversion process. """ missing_keys: set[str] unexpected_keys: set[str] mismatched_keys: set[tuple[str, tuple[int], tuple[int]]] error_msgs: list[str] conversion_errors: dict[str, str] def missing_and_mismatched(self): """Return all effective missing keys, including `missing` and `mismatched` keys.""" return self.missing_keys | {k[0] for k in self.mismatched_keys} def to_dict(self): # Does not include the `conversion_errors` to be coherent with legacy reporting in the tests return { "missing_keys": self.missing_keys, "unexpected_keys": self.unexpected_keys, "mismatched_keys": self.mismatched_keys, "error_msgs": self.error_msgs, } def create_loading_report(self) -> str | None: """Generate the minimal table of a loading report.""" term_w = _get_terminal_width() rows = [] tips = "" if self.unexpected_keys: tips += ( f"\n- {_color('UNEXPECTED', 'orange') + PALETTE['italic']}\t:can be ignored when loading from different " "task/architecture; not ok if you expect identical arch." ) for k in update_key_name(self.unexpected_keys): status = _color("UNEXPECTED", "orange") rows.append([k, status, "", ""]) if self.missing_keys: tips += ( f"\n- {_color('MISSING', 'red') + PALETTE['italic']}\t:those params were newly initialized because missing " "from the checkpoint. Consider training on your downstream task." ) for k in update_key_name(self.missing_keys): status = _color("MISSING", "red") rows.append([k, status, ""]) if self.mismatched_keys: tips += ( f"\n- {_color('MISMATCH', 'yellow') + PALETTE['italic']}\t:ckpt weights were loaded, but they did not match " "the original empty weight shapes." ) iterator = {a: (b, c) for a, b, c in self.mismatched_keys} for key, (shape_ckpt, shape_model) in update_key_name(iterator).items(): status = _color("MISMATCH", "yellow") data = [ key, status, f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}", ] rows.append(data) if self.conversion_errors: tips += f"\n- {_color('CONVERSION', 'purple') + PALETTE['italic']}\t:originate from the conversion scheme" for k, v in update_key_name(self.conversion_errors).items(): status = _color("CONVERSION", "purple") _details = f"\n\n{v}\n\n" rows.append([k, status, _details]) # If nothing is wrong, return None if len(rows) == 0: return None headers = ["Key", "Status"] if term_w > 200: headers += ["Details"] else: headers += ["", ""] table = _make_table(rows, headers=headers) tips = f"\n\n{PALETTE['italic']}Notes:{tips}{PALETTE['reset']}" report = table + tips return report def log_state_dict_report( model, pretrained_model_name_or_path: str, ignore_mismatched_sizes: bool, loading_info: LoadStateDictInfo, logger: logging.Logger | None = None, ): """ Log a readable report about state_dict loading issues. This version is terminal-size aware: for very small terminals it falls back to a compact Key | Status view so output doesn't wrap badly. """ if logger is None: logger = logging.getLogger(__name__) # Re-raise errors early if needed if loading_info.error_msgs: error_msg = "\n\t".join(loading_info.error_msgs) if "size mismatch" in error_msg: error_msg += ( "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate." ) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") # Create the report table report = loading_info.create_loading_report() if report is None: return prelude = f"{PALETTE['bold']}{model.__class__.__name__} LOAD REPORT{PALETTE['reset']} from: {pretrained_model_name_or_path}\n" # Log the report as warning logger.warning(prelude + report) # Re-raise in those case, after the report if loading_info.conversion_errors: raise RuntimeError( "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of " "the above report!" ) if not ignore_mismatched_sizes and loading_info.mismatched_keys: raise RuntimeError( "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!" )