You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

281 lines
9.8 KiB

4 days ago
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import shutil
import sys
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Any
# Matches an integer that directly follows a dot and is itself followed by a dot or by the end of the key
_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)")


def _pattern_of(key: str) -> str:
    """
    Return the structure of `key`, i.e. `key` with every dot-delimited integer swapped for '*'.

    For instance 'layers.0.attn.1' gives 'layers.*.attn.*'. An integer that does not follow a dot (such as
    the leading one in '0.weight') is left as is.
    """
    structure = re.sub(_DIGIT_RX, "*", key)
    return structure
def _fmt_indices(values: list[int], cutoff=10) -> str:
    """
    Render a list of ints compactly.

    - exactly one value: that value alone;
    - more than `cutoff` values: 'smallest...largest';
    - otherwise: all the values in ascending order, separated by ', '.

    No surrounding braces are added here.
    """
    if len(values) == 1:
        return str(values[0])

    ordered = sorted(values)
    if len(ordered) <= cutoff:
        return ", ".join(str(value) for value in ordered)

    lowest, highest = ordered[0], ordered[-1]
    return f"{lowest}...{highest}"
def _render_merged_key(template: str, index_sets: list[set[int]]) -> str:
    """Rebuild a key from `template`, swapping its i-th dot-delimited integer for the i-th set of indices."""
    pieces = []
    cursor = 0
    for match, indices in zip(_DIGIT_RX.finditer(template), index_sets):
        rendered = _fmt_indices(sorted(indices))
        if len(indices) > 1:
            # Several indices share this position: mark the merge with braces
            rendered = "{" + rendered + "}"
        pieces.append(template[cursor : match.start()])
        pieces.append(rendered)
        cursor = match.end()
    pieces.append(template[cursor:])
    return "".join(pieces)


def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
    """
    Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
    BUT only merge together keys that have the exact same value.

    Two keys are merged when they are identical up to their dot-delimited integers and carry equal values.
    Every integer position is rendered on its own: a single index is kept as is, several indices are wrapped
    in braces, and long runs are shortened by `_fmt_indices`. Keys whose value is not hashable are never
    merged, as they cannot be grouped by value.

    Args:
        mapping: Either a dict `{key: value}`, or any other iterable of keys (e.g. a set). A plain iterable
            has no values to compare, so its keys are merged on their structure alone.

    Returns:
        For a dict input, a new `OrderedDict` `{merged_key: value}`. For any other iterable, a view over the
        merged keys only (not a dict, despite the annotation).
    """
    has_values = isinstance(mapping, dict)
    # Keys that come without values all share one placeholder, so that they merge on structure alone
    items = mapping.items() if has_values else ((key, None) for key in mapping)

    # Fresh marker: it cannot be equal to any pattern, so ids built from it never clash with regular ones
    unmergeable = object()
    # group id -> (value, original keys of the group, one set of seen indices per integer position)
    groups: dict[Any, tuple[Any, list[str], list[set[int]]]] = {}
    for key, val in items:
        digits = _DIGIT_RX.findall(key)
        # The number of integers is part of the id: a literal '*' in a key yields the same pattern as an
        # integer would, and such keys must not share their index sets
        group_id = (_pattern_of(key), len(digits), val)
        try:
            group = groups.get(group_id)
        except TypeError:
            # Unhashable value: keep this key in a group of its own
            group_id = (unmergeable, key)
            group = None
        if group is None:
            group = groups[group_id] = (val, [], [set() for _ in digits])
        _, members, index_sets = group
        members.append(key)
        for index_set, digit in zip(index_sets, digits):
            index_set.add(int(digit))

    out = OrderedDict()
    for val, members, index_sets in groups.values():
        merged = _render_merged_key(members[0], index_sets)
        if merged in out:
            # Positions are rendered independently, so two groups with different values can end up with
            # the same merged name. List the original keys rather than overwriting the earlier entry
            for member in members:
                out[member] = val
        else:
            out[merged] = val

    if has_values:
        return out
    return out.keys()
# An ANSI color/style escape sequence: ESC, '[', any run of digits and ';', then 'm'
_ansi_re = re.compile(r"\x1b\[[0-9;]*m")


def _strip_ansi(s: str) -> str:
    """Return `str(s)` with every ANSI color/style escape sequence removed."""
    plain = re.sub(_ansi_re, "", str(s))
    return plain
def _pad(text, width):
    """
    Right-pad `str(text)` with spaces until its visible length reaches `width`.

    The visible length ignores ANSI escape sequences. Text that is already `width` wide or wider is returned
    unchanged (it is never truncated).
    """
    rendered = str(text)
    visible_length = len(_strip_ansi(rendered))
    missing = width - visible_length
    if missing <= 0:
        return rendered
    return rendered + " " * missing
def _make_table(rows, headers):
    """
    Render `rows` under `headers` as a plain-text table.

    Cells are joined with ' | ' and a '-+-' rule separates the header line from the body. Column widths are
    measured on the visible text, so ANSI color codes inside a cell do not break the alignment.

    The number of columns is the length of the shortest line among `headers` and `rows`; the extra trailing
    cells of longer lines are dropped.

    Args:
        rows: The body lines, each one a sequence of cells.
        headers: The column titles.

    Returns:
        The table as a single newline-joined string. Without any row, it is made of the header line and the
        rule only.
    """
    # Transpose header + body into columns. Doing it unconditionally keeps one column per header when `rows`
    # is empty (treating `[headers]` as the columns would collapse every title into a single column)
    columns = zip(headers, *rows)
    widths = [max(len(_strip_ansi(cell)) for cell in column) for column in columns]

    lines = [
        " | ".join(_pad(title, width) for title, width in zip(headers, widths)),
        "-+-".join("-" * width for width in widths),
    ]
    lines.extend(" | ".join(_pad(cell, width) for cell, width in zip(row, widths)) for row in rows)
    return "\n".join(lines)
# Style name -> sequence put in front of the text by `_color` ('reset' is the one put after it).
# NOTE(review): every entry is an empty string here, which makes the styling a no-op. Presumably these are
# meant to hold ANSI escape codes; confirm they have not been lost.
PALETTE = {
    "reset": "",
    "red": "",
    "yellow": "",
    "orange": "",
    "purple": "",
    "bold": "",
    "italic": "",
    "dim": "",
}


def _color(s, color):
    """
    Wrap `s` in the `PALETTE` sequence named `color`, followed by the 'reset' one, when `sys.stdout` is
    interactive (e.g. connected to a terminal).

    Otherwise `s` is handed back untouched, without even looking `color` up.
    """
    if not sys.stdout.isatty():
        return s
    return f"{PALETTE[color]}{s}{PALETTE['reset']}"
def _get_terminal_width(default=80):
    """
    Return the width of the terminal, in columns.

    Args:
        default: The width to report when the actual one cannot be determined (e.g. no terminal is attached
            and the `COLUMNS` environment variable is not set).

    Returns:
        The number of columns reported by `shutil.get_terminal_size`, or `default`.
    """
    try:
        # Without an explicit fallback, `shutil` answers 80 columns on its own when the size is unknown,
        # whatever `default` is. The number of lines is irrelevant here, 24 is the value `shutil` uses.
        return shutil.get_terminal_size(fallback=(default, 24)).columns
    except Exception:
        # Best effort: the width only drives cosmetic layout choices, so never let the lookup fail
        return default
@dataclass
class LoadStateDictInfo:
"""
Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable,
and will usually be mutated in-place during the loading pipeline.
Attributes:
missing_keys (`set[str]`):
Keys that are missing from the loaded checkpoints but expected in the model's architecture.
unexpected_keys (`set[str]`):
Keys that are found in the checkpoints, but not expected in the model's architecture.
mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`):
Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape.
error_msgs ( `list[str]`):
Some potential error messages.
conversion_errors (`dict[str, str]`):
Errors happening during the on-the-fly weight conversion process.
"""
missing_keys: set[str]
unexpected_keys: set[str]
mismatched_keys: set[tuple[str, tuple[int], tuple[int]]]
error_msgs: list[str]
conversion_errors: dict[str, str]
def missing_and_mismatched(self):
"""Return all effective missing keys, including `missing` and `mismatched` keys."""
return self.missing_keys | {k[0] for k in self.mismatched_keys}
def to_dict(self):
# Does not include the `conversion_errors` to be coherent with legacy reporting in the tests
return {
"missing_keys": self.missing_keys,
"unexpected_keys": self.unexpected_keys,
"mismatched_keys": self.mismatched_keys,
"error_msgs": self.error_msgs,
}
def create_loading_report(self) -> str | None:
"""Generate the minimal table of a loading report."""
term_w = _get_terminal_width()
rows = []
tips = ""
if self.unexpected_keys:
tips += (
f"\n- {_color('UNEXPECTED', 'orange') + PALETTE['italic']}\t:can be ignored when loading from different "
"task/architecture; not ok if you expect identical arch."
)
for k in update_key_name(self.unexpected_keys):
status = _color("UNEXPECTED", "orange")
rows.append([k, status, "", ""])
if self.missing_keys:
tips += (
f"\n- {_color('MISSING', 'red') + PALETTE['italic']}\t:those params were newly initialized because missing "
"from the checkpoint. Consider training on your downstream task."
)
for k in update_key_name(self.missing_keys):
status = _color("MISSING", "red")
rows.append([k, status, ""])
if self.mismatched_keys:
tips += (
f"\n- {_color('MISMATCH', 'yellow') + PALETTE['italic']}\t:ckpt weights were loaded, but they did not match "
"the original empty weight shapes."
)
iterator = {a: (b, c) for a, b, c in self.mismatched_keys}
for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
status = _color("MISMATCH", "yellow")
data = [
key,
status,
f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}",
]
rows.append(data)
if self.conversion_errors:
tips += f"\n- {_color('CONVERSION', 'purple') + PALETTE['italic']}\t:originate from the conversion scheme"
for k, v in update_key_name(self.conversion_errors).items():
status = _color("CONVERSION", "purple")
_details = f"\n\n{v}\n\n"
rows.append([k, status, _details])
# If nothing is wrong, return None
if len(rows) == 0:
return None
headers = ["Key", "Status"]
if term_w > 200:
headers += ["Details"]
else:
headers += ["", ""]
table = _make_table(rows, headers=headers)
tips = f"\n\n{PALETTE['italic']}Notes:{tips}{PALETTE['reset']}"
report = table + tips
return report
def log_state_dict_report(
    model,
    pretrained_model_name_or_path: str,
    ignore_mismatched_sizes: bool,
    loading_info: LoadStateDictInfo,
    logger: logging.Logger | None = None,
):
    """
    Log a readable report about state_dict loading issues, and raise when the loading cannot go on.

    Steps, in order:
    1. any entry in `loading_info.error_msgs` raises right away, before a report is built;
    2. the report of `loading_info.create_loading_report()` is logged as a warning (nothing is logged, and the
       function returns, when there is no report);
    3. after the report, conversion errors raise, and so do mismatched keys unless `ignore_mismatched_sizes`
       is set.

    Args:
        model: The model being loaded; only its class name is used, in the messages.
        pretrained_model_name_or_path: Where the weights come from, shown in the report title.
        ignore_mismatched_sizes: Whether mismatched keys are tolerated instead of raising.
        loading_info: The loading diagnostics to report.
        logger: The logger to emit the report with. Defaults to this module's logger.

    Raises:
        RuntimeError: On error messages, conversion errors, or non-tolerated mismatched keys.
    """
    log = logging.getLogger(__name__) if logger is None else logger

    # Hard errors come first: raise before any report is built
    if loading_info.error_msgs:
        error_msg = "\n\t".join(loading_info.error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    report = loading_info.create_loading_report()
    if report is None:
        # Nothing went wrong, so there is nothing to log
        return

    prelude = f"{PALETTE['bold']}{model.__class__.__name__} LOAD REPORT{PALETTE['reset']} from: {pretrained_model_name_or_path}\n"
    log.warning(prelude + report)

    # The failures below are raised only now, so that the report explaining them is logged beforehand
    if loading_info.conversion_errors:
        raise RuntimeError(
            "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
            "the above report!"
        )
    if loading_info.mismatched_keys and not ignore_mismatched_sizes:
        raise RuntimeError(
            "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
        )