# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for SAM2.
"""

from copy import deepcopy

import numpy as np

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, auto_docstring, is_torch_available, logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch


@requires(backends=("torch",))
@auto_docstring
class Sam2Processor(ProcessorMixin):
    def __init__(self, image_processor, target_size: int | None = None, point_pad_value: int = -10, **kwargs):
        r"""
        target_size (`int`, *optional*):
            The target size (in pixels) for normalizing input points and bounding boxes. If not provided, defaults to
            the image processor's size configuration. All input coordinates (points and boxes) are normalized to this
            size before being passed to the model. This ensures consistent coordinate representation regardless of the
            original image dimensions.
        point_pad_value (`int`, *optional*, defaults to -10):
            The value used for padding input points when batching sequences of different lengths. This value is used
            to mark padded positions and is preserved during coordinate normalization.
        """
        super().__init__(image_processor, **kwargs)
        self.point_pad_value = point_pad_value
        self.target_size = target_size if target_size is not None else self.image_processor.size["height"]

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        segmentation_maps: ImageInput | None = None,
        input_points: list[list[list[list[float]]]] | torch.Tensor | None = None,
        input_labels: list[list[list[int]]] | torch.Tensor | None = None,
        input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
        original_sizes: list[list[float]] | torch.Tensor | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        segmentation_maps (`ImageInput`, *optional*):
            The segmentation maps to process.
        input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
            The points to add to the image, nested as `[image, object, point, (x, y) coordinates]`.
        input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
            The labels for the points (1 for a foreground point, 0 for a background point), nested as
            `[image, object, point]`.
        input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
            The bounding boxes to add to the image, nested as `[image, box, (x1, y1, x2, y2) coordinates]`.
        original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
            The original `(height, width)` sizes of the images. Used when `images` is not provided.

        Returns:
            A [`BatchEncoding`] with the following fields:

            - `pixel_values` (`torch.Tensor`): The processed image(s).
            - `original_sizes` (`list[list[float]]`): The original sizes of the images.
            - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
            - `input_points` (`torch.Tensor`): The processed points.
            - `input_labels` (`torch.Tensor`): The processed labels.
            - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
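
        Example:

        A minimal usage sketch; the checkpoint name and image URL below are illustrative, not prescribed by this
        class:

        ```python
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import Sam2Processor

        >>> processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # One image, one object, one foreground point at (x, y) = (500, 375).
        >>> inputs = processor(
        ...     images=image,
        ...     input_points=[[[[500, 375]]]],
        ...     input_labels=[[[1]]],
        ...     return_tensors="pt",
        ... )
        >>> inputs["input_points"].shape  # (num_images, num_objects, num_points, 2)
        torch.Size([1, 1, 1, 2])
        ```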
""" if images is not None: encoding_image_processor = self.image_processor( images, segmentation_maps=segmentation_maps, return_tensors=return_tensors, **kwargs, ) elif original_sizes is not None: if isinstance(original_sizes, torch.Tensor): original_sizes = original_sizes.cpu().tolist() encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors) else: raise ValueError("Either images or original_sizes must be provided") # pop arguments that are not used in the forward but used nevertheless original_sizes = encoding_image_processor["original_sizes"] # Check original_sizes is of length 1 or len(images) if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images): raise ValueError( "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size." ) # Process input points, labels, and boxes if provided if input_points is not None or input_labels is not None or input_boxes is not None: # Validate and convert inputs to standardized format processed_points = self._validate_single_input( input_points, expected_depth=4, input_name="points", expected_format="[image level, object level, point level, point coordinates]", expected_coord_size=2, ) processed_labels = self._validate_single_input( input_labels, expected_depth=3, input_name="labels", expected_format="[image level, object level, point level]", ) processed_boxes = self._validate_single_input( input_boxes, expected_depth=3, input_name="boxes", expected_format="[image level, box level, box coordinates]", expected_coord_size=4, ) # Get padding requirements for all inputs if processed_points is not None: points_max_dims = self._get_nested_dimensions(processed_points)[:3] if processed_labels is not None: labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] if processed_boxes is not None: boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] # Ensure points and labels have consistent dimensions if processed_points is not None and processed_labels is not None: if points_max_dims != labels_max_dims: raise ValueError( "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions." ) # Check that boxes don't need padding (model limitation) if processed_boxes is not None and len(processed_boxes) >= 2: if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes): raise ValueError( "Input boxes have inconsistent dimensions that would require padding, " "but boxes cannot be padded due to model limitations. " "Please ensure all images have the same number of boxes." 
                    )

            # Pad and normalize all inputs to final tensor format
            if processed_points is not None:
                padded_points = self._pad_nested_list(processed_points, points_max_dims + [2])
                final_points = torch.tensor(padded_points, dtype=torch.float32)
                self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True)
                encoding_image_processor.update({"input_points": final_points})

            if processed_labels is not None:
                padded_labels = self._pad_nested_list(processed_labels, labels_max_dims)
                final_labels = torch.tensor(padded_labels, dtype=torch.int64)
                encoding_image_processor.update({"input_labels": final_labels})

            if processed_boxes is not None:
                final_boxes = torch.tensor(processed_boxes, dtype=torch.float32)
                self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True)
                encoding_image_processor.update({"input_boxes": final_boxes})

        return encoding_image_processor

    def _normalize_coordinates(
        self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False
    ) -> "torch.Tensor":
        """
        Normalize coordinates to the target size. Expects a tensor with 2 values (x, y) in the last dimension and
        requires the original image size in (H, W) format.

        Args:
            target_size (`int`):
                The target size of the image.
            coords (`torch.Tensor`):
                The coordinates to be normalized.
            original_size (`tuple`):
                The original size of the image.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether the coordinates are bounding boxes.
        """
        old_h, old_w = original_size
        new_h, new_w = target_size, target_size
        coords = deepcopy(coords).float()

        if is_bounding_box:
            coords = coords.reshape(-1, 2, 2)

        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)

        if is_bounding_box:
            coords = coords.reshape(-1, 4)

        return coords

    def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
        """
        Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.

        Args:
            data: Input data in any supported format.
            expected_depth: Expected nesting depth.
            current_depth: Current depth in the recursion.

        Returns:
            Nested list representation of the data.
        """
        if data is None:
            return None

        # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
        if isinstance(data, torch.Tensor):  # PyTorch tensor
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:
                # At coordinate level or small tensor
                return data.numpy().tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, np.ndarray):  # NumPy array
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:
                # At coordinate level or small array
                return data.tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, list):
            if current_depth == expected_depth:
                # We've reached the expected depth, return as is
                return data
            else:
                # Continue recursion
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, (int, float)):
            return data
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _get_nested_dimensions(self, nested_list, max_dims=None):
        """
        Get the maximum dimensions at each level of nesting.

        Args:
            nested_list (`list`):
                Nested list structure.
            max_dims (`list`, *optional*):
                Current maximum dimensions (for recursion).

        Returns:
            `list`: A list of maximum dimensions for each nesting level.
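
        Example (illustrative):
            For a ragged input like `[[[0, 1], [2]], [[3]]]`, this returns `[2, 2, 2]`: two entries at the top
            level, at most two entries at the second level, and at most two entries at the innermost level.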
""" if max_dims is None: max_dims = [] if not isinstance(nested_list, list): return max_dims if len(max_dims) == 0: max_dims.append(len(nested_list)) else: max_dims[0] = max(max_dims[0], len(nested_list)) if len(nested_list) > 0: for item in nested_list: if isinstance(item, list): sub_dims = self._get_nested_dimensions(item) # Merge sub_dims into max_dims for i, dim in enumerate(sub_dims): if i + 1 >= len(max_dims): max_dims.append(dim) else: max_dims[i + 1] = max(max_dims[i + 1], dim) return max_dims def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None): """ Recursively pad a nested list to match target dimensions. Args: nested_list (`list`): Nested list to pad. target_dims (`list`): Target dimensions for each level. current_level (`int`, *optional*, defaults to 0): Current nesting level. pad_value (`int`, *optional*): Value to use for padding. Returns: `list`: The padded nested list. """ if pad_value is None: pad_value = self.point_pad_value if current_level >= len(target_dims): return nested_list # Ensure we have a list if not isinstance(nested_list, list): nested_list = [nested_list] # Pad current level current_size = len(nested_list) target_size = target_dims[current_level] # Pad with appropriate values if current_level == len(target_dims) - 1: # At the coordinate level, pad with pad_value nested_list.extend([pad_value] * (target_size - current_size)) else: # At higher levels, pad with nested structures if current_size > 0: # Create appropriately sized template if current_level < len(target_dims) - 2: # For non-coordinate levels, create empty nested structure template_dims = target_dims[current_level + 1 :] template = self._create_empty_nested_structure(template_dims, pad_value) else: # For coordinate level, create list of pad_values template = [pad_value] * target_dims[current_level + 1] nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)]) else: # Create from scratch template_dims = target_dims[current_level + 1 :] template = self._create_empty_nested_structure(template_dims, pad_value) nested_list.extend([deepcopy(template) for _ in range(target_size)]) # Recursively pad sublists if current_level < len(target_dims) - 1: for i in range(len(nested_list)): if isinstance(nested_list[i], list): nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value) return nested_list def _create_empty_nested_structure(self, dims, pad_value): """ Create an empty nested structure with given dimensions filled with pad_value. Args: dims (`list`): The dimensions of the nested structure. pad_value (`int`): The value to fill the structure with. """ if len(dims) == 1: return [pad_value] * dims[0] else: return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])] def _get_nesting_level(self, input_list): """ Get the nesting level of a list structure. Args: input_list (`list`): The list to get the nesting level of. """ if isinstance(input_list, list): if len(input_list) == 0: return 1 return 1 + self._get_nesting_level(input_list[0]) elif isinstance(input_list, (np.ndarray, torch.Tensor)): # For arrays/tensors, the nesting level is the number of dimensions return len(input_list.shape) return 0 def _validate_single_input( self, data: torch.Tensor | np.ndarray | list, expected_depth: int, input_name: str, expected_format: str, expected_coord_size: int | None = None, ) -> list: """ Validate a single input by ensuring proper nesting and raising an error if the input is not valid. 

        Args:
            data (`torch.Tensor`, `np.ndarray`, or `list`):
                Input data to process.
            expected_depth (`int`):
                Expected nesting depth.
            input_name (`str`):
                Name of the input for error messages.
            expected_format (`str`):
                The expected format of the input.
            expected_coord_size (`int`, *optional*):
                Expected coordinate size (2 for points, 4 for boxes, `None` for labels).
        """
        if data is None:
            return None

        # Handle tensors and numpy arrays first
        if isinstance(data, (torch.Tensor, np.ndarray)):
            # For tensors/arrays, we can directly check the number of dimensions
            if data.ndim != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
                )
            elif expected_coord_size is not None:
                if data.shape[-1] != expected_coord_size:
                    raise ValueError(
                        f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
                    )
            return self._convert_to_nested_list(data, expected_depth)

        # Handle nested lists
        if isinstance(data, list):
            current_depth = self._get_nesting_level(data)
            if current_depth != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
                )
            return self._convert_to_nested_list(data, expected_depth)

    def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
        """
        Helper method to normalize coordinates in a tensor across multiple images.

        Args:
            tensor (`torch.Tensor`):
                Input tensor with coordinates.
            original_sizes (`list`):
                Original image sizes.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether coordinates are bounding boxes.
            preserve_padding (`bool`, *optional*, defaults to `False`):
                Whether to preserve padding values (for points).
        """
        if preserve_padding:
            # For points: avoid normalizing pad values
            mask = tensor != self.point_pad_value
            coord_mask = mask.all(dim=-1, keepdim=True)

        for img_idx in range(len(original_sizes)):
            if img_idx < tensor.shape[0]:
                original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
                normalized_coords = self._normalize_coordinates(
                    self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
                )
                if preserve_padding:
                    # Only update non-padded values
                    img_mask = coord_mask[img_idx]
                    tensor[img_idx] = torch.where(
                        img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
                    )
                else:
                    tensor[img_idx] = normalized_coords

    def post_process_masks(
        self,
        masks,
        original_sizes,
        mask_threshold=0.0,
        binarize=True,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        apply_non_overlapping_constraints=False,
        **kwargs,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in
                (height, width) format.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold for binarization and post-processing operations.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            max_hole_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a hole to fill.
            max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a sprinkle to fill.
            apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                Whether to apply non-overlapping constraints to the masks.

        Returns:
            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        """
        return self.image_processor.post_process_masks(
            masks,
            original_sizes,
            mask_threshold,
            binarize,
            max_hole_area,
            max_sprinkle_area,
            apply_non_overlapping_constraints,
            **kwargs,
        )

    @property
    def model_input_names(self):
        image_processor_input_names = self.image_processor.model_input_names
        return list(image_processor_input_names + ["original_sizes"])


__all__ = ["Sam2Processor"]