# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for SAM2.
"""
from copy import deepcopy

import numpy as np

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, auto_docstring, is_torch_available, logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch


@requires(backends=("torch",))
@auto_docstring
class Sam2Processor(ProcessorMixin):
    def __init__(self, image_processor, target_size: int | None = None, point_pad_value: int = -10, **kwargs):
r"""
|
||
|
|
target_size (`int`, *optional*):
|
||
|
|
The target size (in pixels) for normalizing input points and bounding boxes. If not provided, defaults
|
||
|
|
to the image processor's size configuration. All input coordinates (points and boxes) are normalized
|
||
|
|
to this size before being passed to the model. This ensures consistent coordinate representation
|
||
|
|
regardless of the original image dimensions.
|
||
|
|
point_pad_value (`int`, *optional*, defaults to -10):
|
||
|
|
The value used for padding input points when batching sequences of different lengths. This value is
|
||
|
|
used to mark padded positions and is preserved during coordinate normalization.
|
||
|
|
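
        Example (an illustrative sketch; the checkpoint id below is an assumption, not a guaranteed Hub name):

        ```python
        >>> from transformers import Sam2Processor

        >>> processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny")  # doctest: +SKIP
        >>> (processor.target_size, processor.point_pad_value)  # doctest: +SKIP
        (1024, -10)
        ```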
"""
|
||
|
|
        super().__init__(image_processor, **kwargs)
        self.point_pad_value = point_pad_value
        self.target_size = target_size if target_size is not None else self.image_processor.size["height"]

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        segmentation_maps: ImageInput | None = None,
        input_points: list[list[list[list[float]]]] | torch.Tensor | None = None,
        input_labels: list[list[list[int]]] | torch.Tensor | None = None,
        input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
        original_sizes: list[list[float]] | torch.Tensor | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchEncoding:
r"""
|
||
|
|
segmentation_maps (`ImageInput`, *optional*):
|
||
|
|
The segmentation maps to process.
|
||
|
|
input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
|
||
|
|
The points to add to the frame.
|
||
|
|
input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
|
||
|
|
The labels for the points.
|
||
|
|
input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
|
||
|
|
The bounding boxes to add to the frame.
|
||
|
|
original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
|
||
|
|
The original sizes of the images.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
A [`BatchEncoding`] with the following fields:
|
||
|
|
- `pixel_values` (`torch.Tensor`): The processed image(s).
|
||
|
|
- `original_sizes` (`list[list[float]]`): The original sizes of the images.
|
||
|
|
- `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
|
||
|
|
- `input_points` (`torch.Tensor`): The processed points.
|
||
|
|
- `input_labels` (`torch.Tensor`): The processed labels.
|
||
|
|
- `input_boxes` (`torch.Tensor`): The processed bounding boxes.
|
||
|
|
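
        Example (an illustrative sketch; the checkpoint id and point coordinates are assumptions for demonstration):

        ```python
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import Sam2Processor

        >>> processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny")  # doctest: +SKIP
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> # One image, one object, one point: nesting is [image, object, point, (x, y)]
        >>> inputs = processor(images=image, input_points=[[[[450.0, 600.0]]]], return_tensors="pt")  # doctest: +SKIP
        ```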
"""
|
||
|
|
        if images is not None:
            encoding_image_processor = self.image_processor(
                images,
                segmentation_maps=segmentation_maps,
                return_tensors=return_tensors,
                **kwargs,
            )
        elif original_sizes is not None:
            if isinstance(original_sizes, torch.Tensor):
                original_sizes = original_sizes.cpu().tolist()
            encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors)
        else:
            raise ValueError("Either images or original_sizes must be provided")

        # original_sizes is not consumed by the model's forward pass, but it is needed for the
        # coordinate normalization below and for post-processing
        original_sizes = encoding_image_processor["original_sizes"]
        # Check original_sizes is of length 1 or len(images)
        if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images):
            raise ValueError(
                "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size."
            )

        # Process input points, labels, and boxes if provided
        if input_points is not None or input_labels is not None or input_boxes is not None:
            # Validate and convert inputs to standardized format
            processed_points = self._validate_single_input(
                input_points,
                expected_depth=4,
                input_name="points",
                expected_format="[image level, object level, point level, point coordinates]",
                expected_coord_size=2,
            )
            processed_labels = self._validate_single_input(
                input_labels,
                expected_depth=3,
                input_name="labels",
                expected_format="[image level, object level, point level]",
            )
            processed_boxes = self._validate_single_input(
                input_boxes,
                expected_depth=3,
                input_name="boxes",
                expected_format="[image level, box level, box coordinates]",
                expected_coord_size=4,
            )

            # Get padding requirements for all inputs
            if processed_points is not None:
                points_max_dims = self._get_nested_dimensions(processed_points)[:3]
            if processed_labels is not None:
                labels_max_dims = self._get_nested_dimensions(processed_labels)[:3]
            if processed_boxes is not None:
                boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2]

            # Ensure points and labels have consistent dimensions
            if processed_points is not None and processed_labels is not None:
                if points_max_dims != labels_max_dims:
                    raise ValueError(
                        "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions."
                    )

            # Check that boxes don't need padding (model limitation)
            if processed_boxes is not None and len(processed_boxes) >= 2:
                if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes):
                    raise ValueError(
                        "Input boxes have inconsistent dimensions that would require padding, "
                        "but boxes cannot be padded due to model limitations. "
                        "Please ensure all images have the same number of boxes."
                    )

            # Pad and normalize all inputs to final tensor format
            if processed_points is not None:
                padded_points = self._pad_nested_list(processed_points, points_max_dims + [2])
                final_points = torch.tensor(padded_points, dtype=torch.float32)
                self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True)
                encoding_image_processor.update({"input_points": final_points})

            if processed_labels is not None:
                padded_labels = self._pad_nested_list(processed_labels, labels_max_dims)
                final_labels = torch.tensor(padded_labels, dtype=torch.int64)
                encoding_image_processor.update({"input_labels": final_labels})

            if processed_boxes is not None:
                final_boxes = torch.tensor(processed_boxes, dtype=torch.float32)
                self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True)
                encoding_image_processor.update({"input_boxes": final_boxes})

        return encoding_image_processor

    def _normalize_coordinates(
        self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False
    ) -> "torch.Tensor":
"""
|
||
|
|
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
target_size (`int`):
|
||
|
|
The target size of the image.
|
||
|
|
coords (`torch.Tensor`):
|
||
|
|
The coordinates to be normalized.
|
||
|
|
original_size (`tuple`):
|
||
|
|
The original size of the image.
|
||
|
|
is_bounding_box (`bool`, *optional*, defaults to `False`):
|
||
|
|
Whether the coordinates are bounding boxes.
|
||
|
|
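
        Example (an illustrative sketch of the scaling math; `processor` denotes an instantiated `Sam2Processor`):

        ```python
        >>> import torch

        >>> # With original_size=(600, 800) and target_size=1024, x scales by 1024/800 and y by 1024/600
        >>> processor._normalize_coordinates(1024, torch.tensor([[400.0, 300.0]]), (600, 800))  # doctest: +SKIP
        tensor([[512., 512.]])
        ```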
"""
|
||
|
|
        old_h, old_w = original_size
        new_h, new_w = target_size, target_size
        coords = deepcopy(coords).float()

        if is_bounding_box:
            coords = coords.reshape(-1, 2, 2)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)

        if is_bounding_box:
            coords = coords.reshape(-1, 4)

        return coords

    def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
        """
        Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.

        Args:
            data (`torch.Tensor`, `np.ndarray`, `list`, `int`, or `float`):
                Input data in any supported format.
            expected_depth (`int`):
                Expected nesting depth.
            current_depth (`int`, *optional*, defaults to 0):
                Current depth in the recursion.

        Returns:
            `list`: Nested list representation of the data.
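
        Example (an illustrative sketch; `processor` denotes an instantiated `Sam2Processor`):

        ```python
        >>> import numpy as np

        >>> # A (1, 2, 2) array already has the expected depth of 3, so it is converted wholesale
        >>> processor._convert_to_nested_list(np.zeros((1, 2, 2)), expected_depth=3)  # doctest: +SKIP
        [[[0.0, 0.0], [0.0, 0.0]]]
        ```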
"""
|
||
|
|
        if data is None:
            return None

        # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
        if isinstance(data, torch.Tensor):  # PyTorch tensor
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:  # At coordinate level or small tensor
                return data.numpy().tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, np.ndarray):  # NumPy array
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:  # At coordinate level or small array
                return data.tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, list):
            if current_depth == expected_depth:
                # We've reached the expected depth, return as is
                return data
            else:
                # Continue recursion
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, (int, float)):
            return data
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _get_nested_dimensions(self, nested_list, max_dims=None):
        """
        Get the maximum dimensions at each level of nesting.

        Args:
            nested_list (`list`):
                Nested list structure.
            max_dims (`list`, *optional*):
                Current maximum dimensions (for recursion).

        Returns:
            `list`: A list of maximum dimensions for each nesting level.
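
        Example (an illustrative sketch; `processor` denotes an instantiated `Sam2Processor`):

        ```python
        >>> # The outer level has 2 items; the longest inner list has 3
        >>> processor._get_nested_dimensions([[1, 2], [3, 4, 5]])  # doctest: +SKIP
        [2, 3]
        ```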
"""
|
||
|
|
        if max_dims is None:
            max_dims = []

        if not isinstance(nested_list, list):
            return max_dims

        if len(max_dims) == 0:
            max_dims.append(len(nested_list))
        else:
            max_dims[0] = max(max_dims[0], len(nested_list))

        if len(nested_list) > 0:
            for item in nested_list:
                if isinstance(item, list):
                    sub_dims = self._get_nested_dimensions(item)
                    # Merge sub_dims into max_dims
                    for i, dim in enumerate(sub_dims):
                        if i + 1 >= len(max_dims):
                            max_dims.append(dim)
                        else:
                            max_dims[i + 1] = max(max_dims[i + 1], dim)

        return max_dims

    def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None):
        """
        Recursively pad a nested list to match target dimensions.

        Args:
            nested_list (`list`):
                Nested list to pad.
            target_dims (`list`):
                Target dimensions for each level.
            current_level (`int`, *optional*, defaults to 0):
                Current nesting level.
            pad_value (`int`, *optional*):
                Value to use for padding. Defaults to `self.point_pad_value`.

        Returns:
            `list`: The padded nested list.
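
        Example (an illustrative sketch; -10 is the default `point_pad_value`):

        ```python
        >>> # Pad a single [x, y] point up to two points for this object
        >>> processor._pad_nested_list([[1.0, 2.0]], [2, 2])  # doctest: +SKIP
        [[1.0, 2.0], [-10, -10]]
        ```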
"""
|
||
|
|
        if pad_value is None:
            pad_value = self.point_pad_value

        if current_level >= len(target_dims):
            return nested_list

        # Ensure we have a list
        if not isinstance(nested_list, list):
            nested_list = [nested_list]

        # Pad current level
        current_size = len(nested_list)
        target_size = target_dims[current_level]

        # Pad with appropriate values
        if current_level == len(target_dims) - 1:
            # At the coordinate level, pad with pad_value
            nested_list.extend([pad_value] * (target_size - current_size))
        else:
            # At higher levels, pad with nested structures
            if current_size > 0:
                # Create appropriately sized template
                if current_level < len(target_dims) - 2:
                    # For non-coordinate levels, create empty nested structure
                    template_dims = target_dims[current_level + 1 :]
                    template = self._create_empty_nested_structure(template_dims, pad_value)
                else:
                    # For coordinate level, create list of pad_values
                    template = [pad_value] * target_dims[current_level + 1]

                nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)])
            else:
                # Create from scratch
                template_dims = target_dims[current_level + 1 :]
                template = self._create_empty_nested_structure(template_dims, pad_value)
                nested_list.extend([deepcopy(template) for _ in range(target_size)])

        # Recursively pad sublists
        if current_level < len(target_dims) - 1:
            for i in range(len(nested_list)):
                if isinstance(nested_list[i], list):
                    nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value)

        return nested_list

    def _create_empty_nested_structure(self, dims, pad_value):
        """
        Create an empty nested structure with given dimensions filled with pad_value.

        Args:
            dims (`list`):
                The dimensions of the nested structure.
            pad_value (`int`):
                The value to fill the structure with.
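
        Example (an illustrative sketch):

        ```python
        >>> processor._create_empty_nested_structure([2, 3], -10)  # doctest: +SKIP
        [[-10, -10, -10], [-10, -10, -10]]
        ```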
"""
|
||
|
|
        if len(dims) == 1:
            return [pad_value] * dims[0]
        else:
            return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])]

    def _get_nesting_level(self, input_list):
        """
        Get the nesting level of a list structure. Arrays and tensors count their number of dimensions.

        Args:
            input_list (`list`, `np.ndarray`, or `torch.Tensor`):
                The structure to get the nesting level of.
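
        Example (an illustrative sketch; points nest as [image, object, point, (x, y)]):

        ```python
        >>> processor._get_nesting_level([[[[450.0, 600.0]]]])  # doctest: +SKIP
        4
        ```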
"""
|
||
|
|
        if isinstance(input_list, list):
            if len(input_list) == 0:
                return 1
            return 1 + self._get_nesting_level(input_list[0])
        elif isinstance(input_list, (np.ndarray, torch.Tensor)):
            # For arrays/tensors, the nesting level is the number of dimensions
            return len(input_list.shape)
        return 0

    def _validate_single_input(
        self,
        data: torch.Tensor | np.ndarray | list,
        expected_depth: int,
        input_name: str,
        expected_format: str,
        expected_coord_size: int | None = None,
    ) -> list:
"""
|
||
|
|
Validate a single input by ensuring proper nesting and raising an error if the input is not valid.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
data (`torch.Tensor`, `np.ndarray`, or `list`):
|
||
|
|
Input data to process.
|
||
|
|
expected_depth (`int`):
|
||
|
|
Expected nesting depth.
|
||
|
|
input_name (`str`):
|
||
|
|
Name of the input for error messages.
|
||
|
|
expected_format (`str`):
|
||
|
|
The expected format of the input.
|
||
|
|
expected_coord_size (`int`, *optional*):
|
||
|
|
Expected coordinate size (2 for points, 4 for boxes, None for labels).
|
||
|
|
.
|
||
|
|
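
        Example (an illustrative sketch; the input is a valid 4-level points structure and is returned unchanged):

        ```python
        >>> processor._validate_single_input(
        ...     [[[[450.0, 600.0]]]],
        ...     expected_depth=4,
        ...     input_name="points",
        ...     expected_format="[image level, object level, point level, point coordinates]",
        ...     expected_coord_size=2,
        ... )  # doctest: +SKIP
        [[[[450.0, 600.0]]]]
        ```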
"""
|
||
|
|
        if data is None:
            return None

        # Handle tensors and numpy arrays first
        if isinstance(data, (torch.Tensor, np.ndarray)):
            # For tensors/arrays, we can directly check the number of dimensions
            if data.ndim != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
                )
            elif expected_coord_size is not None:
                if data.shape[-1] != expected_coord_size:
                    raise ValueError(
                        f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
                    )
            return self._convert_to_nested_list(data, expected_depth)

        # Handle nested lists
        if isinstance(data, list):
            current_depth = self._get_nesting_level(data)
            if current_depth != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
                )
            return self._convert_to_nested_list(data, expected_depth)

    def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
        """
        Helper method to normalize coordinates in a tensor across multiple images. The tensor is modified in place.

        Args:
            tensor (`torch.Tensor`):
                Input tensor with coordinates.
            original_sizes (`list`):
                Original image sizes.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether coordinates are bounding boxes.
            preserve_padding (`bool`, *optional*, defaults to `False`):
                Whether to preserve padding values (for points).
"""
|
||
|
|
        if preserve_padding:
            # For points: avoid normalizing pad values
            mask = tensor != self.point_pad_value
            coord_mask = mask.all(dim=-1, keepdim=True)

        for img_idx in range(len(original_sizes)):
            if img_idx < tensor.shape[0]:
                original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
                normalized_coords = self._normalize_coordinates(
                    self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
                )

                if preserve_padding:
                    # Only update non-padded values
                    img_mask = coord_mask[img_idx]
                    tensor[img_idx] = torch.where(
                        img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
                    )
                else:
                    tensor[img_idx] = normalized_coords

    def post_process_masks(
        self,
        masks,
        original_sizes,
        mask_threshold=0.0,
        binarize=True,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        apply_non_overlapping_constraints=False,
        **kwargs,
    ):
"""
|
||
|
|
Remove padding and upscale masks to the original image size.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
|
||
|
|
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||
|
|
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
|
||
|
|
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
||
|
|
width) format.
|
||
|
|
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||
|
|
Threshold for binarization and post-processing operations.
|
||
|
|
binarize (`bool`, *optional*, defaults to `True`):
|
||
|
|
Whether to binarize the masks.
|
||
|
|
max_hole_area (`float`, *optional*, defaults to 0.0):
|
||
|
|
The maximum area of a hole to fill.
|
||
|
|
max_sprinkle_area (`float`, *optional*, defaults to 0.0):
|
||
|
|
The maximum area of a sprinkle to fill.
|
||
|
|
apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
|
||
|
|
Whether to apply non-overlapping constraints to the masks.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
(`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
|
||
|
|
is given by original_size.
|
||
|
|
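
        Example (an illustrative sketch; `outputs` and `inputs` are assumed to come from a prior SAM2 model call):

        ```python
        >>> masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])  # doctest: +SKIP
        ```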
"""
|
||
|
|
        return self.image_processor.post_process_masks(
            masks,
            original_sizes,
            mask_threshold,
            binarize,
            max_hole_area,
            max_sprinkle_area,
            apply_non_overlapping_constraints,
            **kwargs,
        )

    @property
    def model_input_names(self):
        image_processor_input_names = self.image_processor.model_input_names
        return list(image_processor_input_names + ["original_sizes"])


__all__ = ["Sam2Processor"]
|