You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

70 lines
3.2 KiB

import torch
from ...utils import TensorType
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
from .modeling_efficientloftr import EfficientLoFTRKeypointMatchingOutput
class EfficientLoFTRImageProcessorFast(SuperGlueImageProcessorFast):
    """
    Image processor for EfficientLoFTR, derived from [`SuperGlueImageProcessorFast`].

    On top of what the parent class provides, this class defines the post-processing of
    [`EfficientLoFTRKeypointMatchingOutput`] into per-image-pair matched keypoints.
    """

    def post_process_keypoint_matching(
        self,
        outputs: "EfficientLoFTRKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        """
        Converts the raw output of [`EfficientLoFTRKeypointMatchingOutput`] into lists of matched keypoints and
        matching scores, with keypoint coordinates scaled to the original image sizes.

        Args:
            outputs ([`EfficientLoFTRKeypointMatchingOutput`]):
                Raw outputs of the model. The attributes `keypoints`, `matches` and `matching_scores` are read.
            target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`):
                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
                target size `(height, width)` of each image in the batch. This must be the original image size (before
                any processing). A tensor is moved to the device of `outputs.keypoints` if it lives elsewhere.
            threshold (`float`, *optional*, defaults to 0.0):
                Threshold to filter out the matches with low scores. A match is kept only if its score is strictly
                greater than `threshold` and its match index is greater than -1.

        Returns:
            `List[Dict]`: A list with one dictionary per image pair, each containing:
                - `"keypoints0"`: the kept keypoints of the first image, as `torch.int32` coordinates.
                - `"keypoints1"`: the kept keypoints of the second image, as `torch.int32` coordinates.
                - `"matching_scores"`: the scores of the kept matches, taken from the first image's side.

        Raises:
            ValueError: If the number of entries in `target_sizes` differs from the batch dimension of
                `outputs.matches`, if an entry does not hold exactly two sizes, or if a tensor `target_sizes` is not
                of shape `(batch_size, 2, 2)`.
        """
        if outputs.matches.shape[0] != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
        if not all(len(target_size) == 2 for target_size in target_sizes):
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        # The sizes are multiplied with the keypoints below, so both operands must share a device.
        device = outputs.keypoints.device
        if isinstance(target_sizes, list):
            image_pair_sizes = torch.tensor(target_sizes, device=device)
        else:
            # Compare the whole trailing shape: indexing `shape[2]` directly would raise an IndexError on a
            # 2-D tensor instead of the intended ValueError, and would let tensors with extra dimensions through.
            if tuple(target_sizes.shape[1:]) != (2, 2):
                raise ValueError(
                    "Each element of target_sizes must contain the size (h, w) of each image of the batch"
                )
            image_pair_sizes = target_sizes.to(device)

        # Sizes are given as (height, width); flipping the last axis yields (width, height). The reshape to
        # (batch, 2, 1, 2) lets one size per image broadcast over all keypoints of that image.
        # NOTE(review): assumes `outputs.keypoints` holds relative (x, y) coordinates of shape
        # (batch, 2, num_keypoints, 2) - confirm against the model output.
        scale = image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
        # The multiplication is out-of-place, so `outputs.keypoints` is left untouched without an explicit copy.
        keypoints = (outputs.keypoints * scale).to(torch.int32)

        results = []
        for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
            # Keep only entries that have a match (index > -1) and a score above the threshold.
            valid_matches = torch.logical_and(scores > threshold, matches > -1)

            results.append(
                {
                    "keypoints0": keypoints_pair[0][valid_matches[0]],
                    "keypoints1": keypoints_pair[1][valid_matches[1]],
                    "matching_scores": scores[0][valid_matches[0]],
                }
            )

        return results
# Public API of this module: only the fast image processor class is exported by `from ... import *`.
__all__ = ["EfficientLoFTRImageProcessorFast"]