You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

127 lines
5.1 KiB

# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Pop2Piano."""
import os
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_python import BatchEncoding, PaddingStrategy, TruncationStrategy
from ...utils import TensorType, auto_docstring
from ...utils.import_utils import requires
@requires(backends=("essentia", "librosa", "pretty_midi", "scipy", "torch"))
@auto_docstring
class Pop2PianoProcessor(ProcessorMixin):
    # Bundles a feature extractor (audio -> model inputs) and a tokenizer
    # (notes <-> token ids) behind a single callable object.

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        audio: np.ndarray | list[float] | list[np.ndarray] = None,
        sampling_rate: int | list[int] | None = None,
        steps_per_beat: int = 2,
        resample: bool | None = True,
        notes: list | TensorType = None,
        padding: bool | str | PaddingStrategy = False,
        truncation: bool | str | TruncationStrategy = None,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        verbose: bool = True,
        **kwargs,
    ) -> BatchFeature | BatchEncoding:
        r"""
        sampling_rate (`int` or `list[int]`, *optional*):
            The sampling rate of the input audio in Hz. This should match the sampling rate used by the feature
            extractor. If not provided, the default sampling rate from the processor configuration will be used.
        steps_per_beat (`int`, *optional*, defaults to `2`):
            The number of time steps per musical beat. This parameter controls the temporal resolution of the
            musical representation. A higher value provides finer temporal granularity but increases the sequence
            length. Used when processing audio to extract musical features.
        notes (`list` or `TensorType`, *optional*):
            Pre-extracted musical notes in MIDI format. When provided, the processor skips audio feature extraction
            and directly processes the notes through the tokenizer. Each note should be represented as a list or
            tensor containing pitch, velocity, and timing information.
        """
        # The feature extractor is only invoked when BOTH `audio` and `sampling_rate` are supplied;
        # the tokenizer is only invoked when `notes` is supplied.
        use_feature_extractor = audio is not None and sampling_rate is not None
        use_tokenizer = notes is not None

        # Reject every combination that would leave nothing to run. Checking only for
        # "everything is None" lets e.g. `audio` without `sampling_rate` slip through and
        # fail later with an UnboundLocalError instead of this explicit message.
        if not use_feature_extractor and not use_tokenizer:
            raise ValueError(
                "You have to specify at least audios and sampling_rate in order to use feature extractor or "
                "notes to use the tokenizer part."
            )

        inputs = None
        if use_feature_extractor:
            inputs = self.feature_extractor(
                audio=audio,
                sampling_rate=sampling_rate,
                steps_per_beat=steps_per_beat,
                resample=resample,
                **kwargs,
            )

        # Audio-only call: return the feature extractor output untouched.
        if not use_tokenizer:
            return inputs

        encoded_token_ids = self.tokenizer(
            notes=notes,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            verbose=verbose,
            **kwargs,
        )

        # Notes-only call (or incomplete audio arguments): return the tokenizer output.
        if not use_feature_extractor:
            return encoded_token_ids

        # Both parts ran: attach the token ids to the feature extractor output.
        inputs["token_ids"] = encoded_token_ids["token_ids"]
        return inputs

    def batch_decode(
        self,
        token_ids,
        feature_extractor_output: BatchFeature,
        return_midi: bool = True,
    ) -> BatchEncoding:
        """
        This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes.
        Please refer to the docstring of that method for more information.
        """
        return self.tokenizer.batch_decode(
            token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
        )

    def save_pretrained(self, save_directory, **kwargs):
        """
        Create `save_directory` if needed, then delegate to the parent class `save_pretrained`.

        Raises:
            ValueError: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        return super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a processor from the arguments returned by `_get_arguments_from_pretrained`, which are passed
        positionally to the constructor.
        """
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(*args)
# Names exported by `from <this module> import *`; only the processor class is public.
__all__ = ["Pop2PianoProcessor"]