You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
4.3 KiB
112 lines
4.3 KiB
import functools
|
|
import operator
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from typing import Literal, Optional
|
|
|
|
|
|
# Plain `str` aliases: they exist only to make the mapping annotations in this
# module self-describing (which `str` is a file name, which is a tensor name).
FILENAME_T = str
TENSOR_NAME_T = str

# Closed set of data type identifiers used for a tensor's `dtype`.
DTYPE_T = Literal[
    "F64",
    "F32",
    "F16",
    "BF16",
    "I64",
    "I32",
    "I16",
    "I8",
    "U8",
    "BOOL",
]
|
|
|
|
|
|
@dataclass
class TensorInfo:
    """Description of a single tensor stored in a safetensors file.

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            Data type of the tensor, one of "F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL".
        shape (`list[int]`):
            Dimensions of the tensor.
        data_offsets (`tuple[int, int]`):
            Offsets of the tensor data in the file, as a `[BEGIN, END]` pair.
        parameter_count (`int`):
            Number of parameters in the tensor. Not accepted by `__init__`: it is computed from `shape` once the
            instance has been created.
    """

    dtype: DTYPE_T
    shape: list[int]
    data_offsets: tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Set `parameter_count` to the product of the dimensions in `shape`."""
        # Product-by-reduce idiom, taken from https://stackoverflow.com/a/13840436
        try:
            total = functools.reduce(operator.mul, self.shape)
        except TypeError:
            # `reduce` has no initial value here, so an empty shape raises TypeError:
            # a scalar value has no shape and counts as a single parameter.
            total = 1
        self.parameter_count = total
|
|
|
|
|
|
@dataclass
class SafetensorsFileMetadata:
    """Metadata of one Safetensors file hosted on the Hub.

    This class is returned by [`parse_safetensors_file_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`dict`):
            The metadata contained in the file.
        tensors (`dict[str, TensorInfo]`):
            All tensors of the file, keyed by tensor name. Each value is a [`TensorInfo`] object describing the
            corresponding tensor.
        parameter_count (`dict[str, int]`):
            Number of parameters per data type. Keys are data types and values are the number of parameters of
            that data type. Not accepted by `__init__`: it is aggregated from `tensors`.
    """

    metadata: dict[str, str]
    tensors: dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Aggregate the per-tensor parameter counts into one total per data type."""
        totals: dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            # A data type seen for the first time starts from zero.
            totals[info.dtype] = totals.get(info.dtype, 0) + info.parameter_count
        self.parameter_count = totals
|
|
|
|
|
|
@dataclass
class SafetensorsRepoMetadata:
    """Metadata of a Safetensors repo.

    A repo counts as a Safetensors repo when its root contains either a 'model.safetensors' weight file (non-sharded
    model) or a 'model.safetensors.index.json' index file (sharded model).

    This class is returned by [`get_safetensors_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for
            sharded models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`dict[str, str]`):
            Location of every weight. Keys are tensor names and values are the names of the files containing
            those tensors.
        files_metadata (`dict[str, SafetensorsFileMetadata]`):
            Metadata of every file. Keys are filenames and values are the matching [`SafetensorsFileMetadata`]
            objects.
        parameter_count (`dict[str, int]`):
            Number of parameters per data type. Keys are data types and values are the number of parameters of
            that data type. Not accepted by `__init__`: it is aggregated from `files_metadata`.
    """

    metadata: Optional[dict]
    sharded: bool
    weight_map: dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the per-file parameter counts into one total per data type."""
        totals: dict[DTYPE_T, int] = {}
        for single_file in self.files_metadata.values():
            per_dtype = single_file.parameter_count
            for dtype, nb_parameters in per_dtype.items():
                # A data type seen for the first time starts from zero.
                totals[dtype] = totals.get(dtype, 0) + nb_parameters
        self.parameter_count = totals
|