You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

281 lines
9.8 KiB

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import shutil
import sys
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Any
_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)")  # numbers between dots or at the end


def _pattern_of(key: str) -> str:
    """Replace every dot-delimited integer with '*' to get the structure.

    Only integers preceded by a '.' and followed by '.' or the end of the string are replaced,
    e.g. ``"model.layers.3.mlp"`` -> ``"model.layers.*.mlp"``. A leading integer such as the
    ``0`` in ``"0.weight"`` is left untouched because it has no preceding dot.
    """
    return _DIGIT_RX.sub("*", key)


def _fmt_indices(values: list[int], cutoff=10) -> str:
    """Format a collection of ints for display inside a merged key.

    The values are sorted first. Returns the single number for one value, ``"first...last"``
    when there are more than ``cutoff`` values, and a ``", "``-separated list otherwise.
    No braces are added here; the caller wraps multi-value results in ``{...}``.
    """
    # Sorting first also makes sets (which cannot be indexed) valid input.
    values = sorted(values)
    if len(values) == 1:
        return str(values[0])
    if len(values) > cutoff:
        return f"{values[0]}...{values[-1]}"
    return ", ".join(map(str, values))


def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
    """
    Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
    BUT only merge together keys that have the exact same value.
    Returns a new dict {merged_key: value}.

    Args:
        mapping: Either a dict ``{key: value}`` or any iterable of keys. For a plain iterable, every
            key is treated as having the same value, so keys are grouped by structure alone.

    Returns:
        An ``OrderedDict`` ``{merged_key: value}`` for dict input, or the ``keys()`` view of such a
        dict for non-dict input. Each numeric position is merged independently, so
        ``{a.0.b.1, a.1.b.0}`` is shown as ``a.{0, 1}.b.{0, 1}``.
    """
    not_mapping = not isinstance(mapping, dict)
    if not_mapping:
        # A plain key collection carries no values. Give every key the same (None) value so keys
        # are grouped by structure only; mapping each key to itself would make all values distinct
        # and nothing would ever merge.
        mapping = dict.fromkeys(mapping)

    # pattern -> list of (value, per-star index sets), one entry per distinct value.
    # Values are compared with `==` rather than hashed, so unhashable values keep working.
    groups: dict[str, list[tuple[Any, list[set[int]]]]] = defaultdict(list)
    for key, val in mapping.items():
        digits = [int(d) for d in _DIGIT_RX.findall(key)]
        patt = _pattern_of(key)
        for group_val, index_sets in groups[patt]:
            if group_val is val or group_val == val:
                break
        else:
            index_sets = [set() for _ in digits]
            groups[patt].append((val, index_sets))
        for index_set, d in zip(index_sets, digits):
            index_set.add(d)

    out = OrderedDict()
    for patt, patt_groups in groups.items():
        parts = patt.split("*")  # stars sit between consecutive parts
        for val, index_sets in patt_groups:
            merged = parts[0]
            for i, part in enumerate(parts[1:]):
                indices = index_sets[i] if i < len(index_sets) else None
                if not indices:
                    # A literal '*' in the key has no collected indices; keep it as-is.
                    merged += "*"
                elif len(indices) == 1:
                    merged += _fmt_indices(sorted(indices))
                else:
                    merged += "{" + _fmt_indices(sorted(indices)) + "}"
                merged += part
            out[merged] = val
    if not_mapping:
        return out.keys()
    return out
_ansi_re = re.compile(r"\x1b\[[0-9;]*m")


def _strip_ansi(s: str) -> str:
    """Return ``str(s)`` with ANSI SGR escape sequences (``ESC[...m``) removed."""
    return _ansi_re.sub("", str(s))


def _pad(text, width):
    """Right-pad ``str(text)`` with spaces to ``width`` visible characters.

    ANSI escape codes are not counted towards the width, so colored text lines up.
    """
    t = str(text)
    pad = max(0, width - len(_strip_ansi(t)))
    return t + " " * pad


def _make_table(rows, headers):
    """Render ``rows`` under ``headers`` as a plain-text table.

    Columns are separated by ``" | "`` and the header is underlined with ``-`` / ``-+-``.
    Column widths ignore ANSI codes. The number of columns is that of the shortest row
    (headers included), because ``zip`` truncates; extra cells are dropped.
    """
    # `zip(headers, *rows)` yields one column per header even when `rows` is empty.
    cols = list(zip(headers, *rows))
    widths = [max(len(_strip_ansi(cell)) for cell in col) for col in cols]
    header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths))
    sep_line = "-+-".join("-" * w for w in widths)
    body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows]
    return "\n".join([header_line, sep_line] + body)
# Style codes looked up by name: `_color` wraps text as PALETTE[color] + text + PALETTE["reset"],
# and the report code also inserts "bold"/"italic"/"reset" directly.
# NOTE(review): every value is an empty string in this copy, which makes all styling a no-op.
# The `_ansi_re` stripper and the file's "invisible Unicode characters" warning suggest these were
# meant to hold ANSI SGR escape sequences (e.g. "\x1b[31m" for red) that were lost or are
# invisible here. Confirm against the canonical source before relying on colored output.
PALETTE = {
    "reset": "",
    "red": "",
    "yellow": "",
    "orange": "",
    "purple": "",
    "bold": "",
    "italic": "",
    "dim": "",
}
def _color(s, color):
    """Return `s` wrapped in the `PALETTE` codes for `color` when `sys.stdout` is interactive.

    `s` is returned unchanged when stdout is not a terminal, is `None` (e.g. under `pythonw`),
    or cannot be queried (e.g. it has been closed). A diagnostic helper must not raise here.

    Args:
        s: The text to style.
        color: A key of `PALETTE`, such as "red"; it is only looked up when styling is applied.
    """
    stream = sys.stdout
    try:
        interactive = stream is not None and stream.isatty()
    except (AttributeError, ValueError):
        # AttributeError: stdout replaced by an object without `isatty`;
        # ValueError: `isatty()` called on a closed stream.
        interactive = False
    if interactive:
        return f"{PALETTE[color]}{s}{PALETTE['reset']}"
    return s
def _get_terminal_width(default=80):
    """Return the terminal width in columns, or `default` when it cannot be determined.

    `shutil.get_terminal_size` honors the `COLUMNS` environment variable first, then queries
    the terminal attached to `sys.__stdout__`, and returns its `fallback` when neither works
    (e.g. output is piped). Passing `default` as that fallback means `default` is also honored
    in the non-terminal case, not only when the call raises.
    """
    try:
        return shutil.get_terminal_size(fallback=(default, 24)).columns
    except Exception:
        # Best effort: sizing the report must never break model loading.
        return default
@dataclass
class LoadStateDictInfo:
"""
Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable,
and will usually be mutated in-place during the loading pipeline.
Attributes:
missing_keys (`set[str]`):
Keys that are missing from the loaded checkpoints but expected in the model's architecture.
unexpected_keys (`set[str]`):
Keys that are found in the checkpoints, but not expected in the model's architecture.
mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`):
Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape.
error_msgs ( `list[str]`):
Some potential error messages.
conversion_errors (`dict[str, str]`):
Errors happening during the on-the-fly weight conversion process.
"""
missing_keys: set[str]
unexpected_keys: set[str]
mismatched_keys: set[tuple[str, tuple[int], tuple[int]]]
error_msgs: list[str]
conversion_errors: dict[str, str]
def missing_and_mismatched(self):
"""Return all effective missing keys, including `missing` and `mismatched` keys."""
return self.missing_keys | {k[0] for k in self.mismatched_keys}
def to_dict(self):
# Does not include the `conversion_errors` to be coherent with legacy reporting in the tests
return {
"missing_keys": self.missing_keys,
"unexpected_keys": self.unexpected_keys,
"mismatched_keys": self.mismatched_keys,
"error_msgs": self.error_msgs,
}
def create_loading_report(self) -> str | None:
"""Generate the minimal table of a loading report."""
term_w = _get_terminal_width()
rows = []
tips = ""
if self.unexpected_keys:
tips += (
f"\n- {_color('UNEXPECTED', 'orange') + PALETTE['italic']}\t:can be ignored when loading from different "
"task/architecture; not ok if you expect identical arch."
)
for k in update_key_name(self.unexpected_keys):
status = _color("UNEXPECTED", "orange")
rows.append([k, status, "", ""])
if self.missing_keys:
tips += (
f"\n- {_color('MISSING', 'red') + PALETTE['italic']}\t:those params were newly initialized because missing "
"from the checkpoint. Consider training on your downstream task."
)
for k in update_key_name(self.missing_keys):
status = _color("MISSING", "red")
rows.append([k, status, ""])
if self.mismatched_keys:
tips += (
f"\n- {_color('MISMATCH', 'yellow') + PALETTE['italic']}\t:ckpt weights were loaded, but they did not match "
"the original empty weight shapes."
)
iterator = {a: (b, c) for a, b, c in self.mismatched_keys}
for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
status = _color("MISMATCH", "yellow")
data = [
key,
status,
f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}",
]
rows.append(data)
if self.conversion_errors:
tips += f"\n- {_color('CONVERSION', 'purple') + PALETTE['italic']}\t:originate from the conversion scheme"
for k, v in update_key_name(self.conversion_errors).items():
status = _color("CONVERSION", "purple")
_details = f"\n\n{v}\n\n"
rows.append([k, status, _details])
# If nothing is wrong, return None
if len(rows) == 0:
return None
headers = ["Key", "Status"]
if term_w > 200:
headers += ["Details"]
else:
headers += ["", ""]
table = _make_table(rows, headers=headers)
tips = f"\n\n{PALETTE['italic']}Notes:{tips}{PALETTE['reset']}"
report = table + tips
return report
def log_state_dict_report(
    model,
    pretrained_model_name_or_path: str,
    ignore_mismatched_sizes: bool,
    loading_info: LoadStateDictInfo,
    logger: logging.Logger | None = None,
):
    """
    Report state-dict loading problems for `model`, and raise when they are fatal.

    Steps, in order:
      1. If `loading_info.error_msgs` is non-empty, raise `RuntimeError` right away, without
         logging a report. A hint about `ignore_mismatched_sizes=True` is appended when the
         messages mention "size mismatch".
      2. Build the report with `loading_info.create_loading_report()`; return if it is None.
      3. Log the report as a warning on `logger` (defaults to this module's logger).
      4. Raise `RuntimeError` if there were conversion errors, or if there are mismatched keys
         while `ignore_mismatched_sizes` is False.

    The table layout itself is decided by `LoadStateDictInfo.create_loading_report`.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    class_name = model.__class__.__name__

    # Hard errors abort before any report is produced.
    if loading_info.error_msgs:
        joined = "\n\t".join(loading_info.error_msgs)
        hint = (
            "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
            if "size mismatch" in joined
            else ""
        )
        raise RuntimeError(f"Error(s) in loading state_dict for {class_name}:\n\t{joined}{hint}")

    report = loading_info.create_loading_report()
    if report is None:
        return

    title = f"{PALETTE['bold']}{class_name} LOAD REPORT{PALETTE['reset']}"
    logger.warning(f"{title} from: {pretrained_model_name_or_path}\n{report}")

    # Remaining fatal conditions are raised only after the report has been logged.
    if loading_info.conversion_errors:
        raise RuntimeError(
            "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
            "the above report!"
        )
    if not ignore_mismatched_sizes and loading_info.mismatched_keys:
        raise RuntimeError(
            "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
        )