diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index f0dca7c..0df1f5d 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -18,7 +18,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install pylint
-        pip install -r requirements.txt
+        pip install -r .pylint_requirements.txt
     - name: Analysing the code with pylint
       run: |
        pylint $(git ls-files '*.py') --rcfile=.pylintc
diff --git a/.gitignore b/.gitignore
index 548b01d..8875ff5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,11 +159,11 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
 
-dataset/audio/
-dataset/lyrics/
-dataset/data.json
 
 train/
 model/
 formated_dataset/
-test.py
\ No newline at end of file
+test.py
+
+dataset/
+save.py
\ No newline at end of file
diff --git a/.pylint_requirements.txt b/.pylint_requirements.txt
new file mode 100644
index 0000000..4c94946
--- /dev/null
+++ b/.pylint_requirements.txt
@@ -0,0 +1,13 @@
+requests
+orjson
+jiwer
+transformers
+torch
+torchaudio
+datasets
+accelerate
+bitsandbytes
+evaluate
+librosa
+numpy
+
diff --git a/.pylintc b/.pylintc
index 48f344d..d395a63 100644
--- a/.pylintc
+++ b/.pylintc
@@ -63,7 +63,7 @@ ignore-patterns=^\.#
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis). It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
+ignored-modules=aeneas
 
 # Python code to execute, usually for sys.path manipulation such as
 # pygtk.require().
diff --git a/README.md b/README.md
index f35d78f..3922fad 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,49 @@ dataset
 
 where `0.wav` corresponds to the audio file and `0.txt` corresponds to the lyrics transcription of the audio file.
 
+## Process the dataset
+
+To process the dataset, run the following command:
+
+```bash
+python process_dataset.py --clean
+```
+
+Processing splits the audio into 32-second chunks and splits the lyrics accordingly.
+
+## Test the model
+
+Here is an example of how to test the model:
+
+```py
+import librosa
+import torch
+from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
+
+
+model_name = "Jour/whisper-small-lyric-finetuned"
+audio_file = "PATH_TO_AUDIO_FILE"
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+model = WhisperForConditionalGeneration.from_pretrained(model_name)
+
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    max_new_tokens=128,
+    chunk_length_s=30,
+    device=device,
+)
+
+sample, _ = librosa.load(audio_file, sr=processor.feature_extractor.sampling_rate)
+
+prediction = pipe(sample.copy(), batch_size=8)["text"]
+print(prediction)
+```
+
 ## License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
\ No newline at end of file
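A quick way to sanity-check the processing step documented in the README above is to inspect one exported audio/lyric pair. This is a minimal sketch assuming the default export layout (`dataset/export/audio/` and `dataset/export/lyrics/`); the `0_0` file stem is hypothetical:

```py
# Sanity-check one exported pair ("0_0" is a hypothetical file stem;
# sr=None keeps the file's native sampling rate).
import librosa

audio, sr = librosa.load("dataset/export/audio/0_0.wav", sr=None)
print(f"chunk length: {len(audio) / sr:.1f}s")  # should be at most 32 s

with open("dataset/export/lyrics/0_0.txt", encoding="utf-8") as f:
    print(f.read())
```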
diff --git a/dataset/aeneas_wrapper.py b/dataset/aeneas_wrapper.py
new file mode 100644
index 0000000..998a193
--- /dev/null
+++ b/dataset/aeneas_wrapper.py
@@ -0,0 +1,67 @@
+import json
+import re
+import tempfile
+
+from aeneas.tools.execute_task import ExecuteTaskCLI, RuntimeConfiguration
+
+from dataset.exceptions import AeneasAlignError
+
+
+class AeneasWrapper:
+    """Wrapper class for the Aeneas CLI"""
+
+    def __init__(self) -> None:
+        self._rconf = RuntimeConfiguration()
+        self._rconf[RuntimeConfiguration.MFCC_MASK_NONSPEECH] = True
+        self._rconf[RuntimeConfiguration.MFCC_MASK_NONSPEECH_L3] = True
+        self._rconf[RuntimeConfiguration.TTS_CACHE] = True
+        self._rconf.set_granularity(3)
+
+    def aeneas_cli_exec(self, audio_path: str, lyric_path: str) -> dict:
+        """Align lyrics with audio
+
+        Args:
+            audio_path (str): the path to the audio file
+            lyric_path (str): the path to the lyric file
+
+        Raises:
+            AeneasAlignError: if Aeneas fails to align lyrics
+
+        Returns:
+            dict: a dictionary containing the alignment data
+        """
+
+        tmp_dir = tempfile.mkdtemp()
+
+        with open(lyric_path, "r", encoding="utf-8") as f:
+            lyric = f.read()
+
+        # remove all text between square brackets, e.g. [CHORUS]
+        lyric = re.sub(r"\[.*?\]", "\n", lyric)
+
+        # collapse runs of consecutive newlines into a single newline
+        lyric = re.sub(r"\n{1,}", "\n", lyric).strip()
+
+        # put each word on its own line for word-level alignment
+        lyric = lyric.replace(" ", "\n")
+
+        with open(f"{tmp_dir}/lyric.txt", "w", encoding="utf-8") as f:
+            f.write(lyric)
+
+        args = [
+            "dummy",
+            audio_path,
+            f"{tmp_dir}/lyric.txt",
+            "task_language=en|is_text_type=plain|os_task_file_format=json",
+            f"{tmp_dir}/lyric.json",
+        ]
+
+        exit_code = ExecuteTaskCLI(use_sys=False, rconf=self._rconf).run(arguments=args)
+
+        if exit_code != 0:
+            raise AeneasAlignError("Aeneas failed to align lyrics")
+
+        with open(f"{tmp_dir}/lyric.json", "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        return data
diff --git a/dataset/exceptions.py b/dataset/exceptions.py
new file mode 100644
index 0000000..374a290
--- /dev/null
+++ b/dataset/exceptions.py
@@ -0,0 +1,2 @@
+class AeneasAlignError(Exception):
+    """Raised when Aeneas fails to align lyrics"""
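For reference, this is how the sync map returned by `aeneas_cli_exec` is consumed downstream. A minimal sketch, assuming a working aeneas installation and placeholder paths; the `fragments`, `begin`, `end`, and `lines` keys follow the aeneas JSON sync-map format that the wrapper requests:

```py
# Sketch: consume the alignment produced by AeneasWrapper (paths are placeholders).
from dataset.aeneas_wrapper import AeneasWrapper

wrapper = AeneasWrapper()
alignment = wrapper.aeneas_cli_exec("dataset/audio/song.ogg", "dataset/lyrics/song.txt")

# Each fragment maps a piece of text to a [begin, end] time range in seconds.
for fragment in alignment["fragments"]:
    print(fragment["begin"], fragment["end"], fragment["lines"])
```

This is the structure that `DatasetProcess._split_lyric` walks in the diff below.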
+ """ + + self.lyric_path = lyric_path + self.audio_path = audio_path + self.export_path = export_path + self.sample_rate = sample_rate + + if clean and self.export_path: + self.remove_export_folder() + + self.create_export_folder() + + self.aeneas = AeneasWrapper() + + def create_export_folder(self) -> None: + """Method to create the export folder""" + + if not os.path.exists(self.export_path): + os.makedirs(self.export_path) + + if not os.path.exists(f"{self.export_path}/audio"): + os.makedirs(f"{self.export_path}/audio") + + if not os.path.exists(f"{self.export_path}/lyrics"): + os.makedirs(f"{self.export_path}/lyrics") + + def remove_export_folder(self) -> None: + """Method to remove the export folder""" + + if os.path.exists(self.export_path): + shutil.rmtree(self.export_path) + + def _split_audio( + self, audio_path: str, split_windows: int = 32 + ) -> List[AudioSegment]: + """Method to split audio into 32 seconds segments + + Args: + audio_path (str): the path to the audio file + split_windows (int, optional): the size of the split window in seconds. Defaults to 32. + + Returns: + list: a list of AudioSegment that contain audio split into 32 seconds segments + """ + + audio = AudioSegment.from_file(audio_path) + segments = [] + + for i in range(0, len(audio), split_windows * 1000): + segments.append(audio[i : i + split_windows * 1000]) + + return segments + + def _split_lyric( + self, lyric_path: str, alignement: dict, split_windows: int = 32 + ) -> List[str]: + """Method to split audio into 32 seconds segments with the corresponding lyrics + + Args: + lyric_path (str): the path to the lyric file + alignement (dict): the alignment data + split_windows (int, optional): the size of the split window in seconds. Defaults to 32. + + Returns: + list: a list of list that contain lyrics split into 32 seconds segments + """ + + with open(lyric_path, "r", encoding="utf-8") as f: + lyric = f.read() + + segments = [] + start_idx = 0 + end_idx = 0 + + for fragment in alignement["fragments"]: + end_idx = lyric.find(fragment["lines"][0], end_idx) + windows = (len(segments) + 1) * split_windows + + if float(fragment["begin"]) > windows: + segments.append(lyric[start_idx:end_idx]) + start_idx = end_idx + + segments.append(lyric[start_idx:]) + + return segments + + def _export_audio(self, audios: List[AudioSegment], file_name: str) -> None: + """Method to export audio segments to .wav format + + Args: + audios (List[AudioSegment]): a list of AudioSegment + file_name (str): the name of the file + """ + + for i, audio in enumerate(audios): + path = f"{self.audio_path}/{file_name}_{i}.wav" + + if self.export_path: + path = f"{self.export_path}/audio/{file_name}_{i}.wav" + + if self.sample_rate: + audio = audio.set_frame_rate(self.sample_rate) + + audio.export(path, format="wav") + + def _export_lyric(self, lyrics: List[str], file_name: str) -> None: + """Method to export lyrics segments to .txt format + + Args: + lyrics (List[str]): a list of lyrics + file_name (str): the name of the file + """ + + for i, lyric in enumerate(lyrics): + path = f"{self.lyric_path}/{file_name}_{i}.txt" + + if self.export_path: + path = f"{self.export_path}/lyrics/{file_name}_{i}.txt" + + with open(path, "w", encoding="utf-8") as f: + f.write(lyric) + + def process(self, remove: bool = False) -> None: + """Method to process the dataset + 1. Align lyrics with audio + 2. Split audio into 32 seconds segments + 3. 
diff --git a/download_dataset.py b/download_dataset.py
index bbef722..085cc1f 100644
--- a/download_dataset.py
+++ b/download_dataset.py
@@ -4,7 +4,7 @@
 
 parser = argparse.ArgumentParser(
-    description="Download images from Sonauto dataset",
+    description="Download music from Sonauto API",
 )
 parser.add_argument("--num_images", type=int, default=10000)
 parser.add_argument("--clean", type=bool, default=True)
diff --git a/process_dataset.py b/process_dataset.py
new file mode 100644
index 0000000..16fb937
--- /dev/null
+++ b/process_dataset.py
@@ -0,0 +1,24 @@
+import argparse
+
+from dataset.process import DatasetProcess
+
+
+parser = argparse.ArgumentParser(
+    description="Process the dataset",
+)
+parser.add_argument("--audio_path", type=str, default="dataset/audio")
+parser.add_argument("--lyric_path", type=str, default="dataset/lyrics")
+parser.add_argument("--export_path", type=str, default="dataset/export")
+parser.add_argument("--sample_rate", type=int, default=None)
+parser.add_argument("--clean", action="store_true")
+
+args = parser.parse_args()
+process = DatasetProcess(
+    lyric_path=args.lyric_path,
+    audio_path=args.audio_path,
+    sample_rate=args.sample_rate,
+    export_path=args.export_path,
+    clean=args.clean,
+)
+
+process.process()
diff --git a/requirements.txt b/requirements.txt
index 72a6856..ae1418e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,6 @@ accelerate
 bitsandbytes
 evaluate
 librosa
+numpy
+aeneas
+pydub
diff --git a/train.py b/train.py
index 8332e62..8ded7e9 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,11 @@
-from datasets import DatasetDict
-
-from training.train import Trainer
-
 from training import utils
+from training.train import Trainer
 
-LOAD_DATASET = True
+dataset = utils.gather_dataset("./dataset/export")
+dataset = dataset.train_test_split(test_size=0.1)
 
-if LOAD_DATASET:
-    dataset = utils.gather_dataset("./dataset")
-else:
-    dataset = DatasetDict.load_from_disk("./formated_dataset")
 trainer = Trainer(dataset)
-if LOAD_DATASET:
-    for i in range(dataset.num_rows//1000):
-        dataset = trainer.process_dataset(dataset, i)
-        dataset.save_to_disk(f"./formated_dataset_{i}")
-
+dataset = trainer.process_dataset(dataset)
 trainer.train()
 trainer.model.save_pretrained("./model")
 trainer.processor.save_pretrained("./model")
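One detail worth keeping in mind when running `process_dataset.py` before `train.py`: Whisper's feature extractor operates at 16 kHz, so exporting with `--sample_rate 16000` produces chunks that already match what the model expects. A quick way to confirm the rate (a sketch, not part of the repo):

```py
# Sketch: check the sampling rate the Whisper feature extractor expects.
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
print(processor.feature_extractor.sampling_rate)  # 16000 for Whisper checkpoints
```

The training side derives its target rate from this same attribute, as `prepare_dataset` in the diff below shows.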
diff --git a/training/train.py b/training/train.py
index 6601495..cc3f7b9 100644
--- a/training/train.py
+++ b/training/train.py
@@ -1,6 +1,7 @@
 """
 This module contains the Trainer class which is responsible for training whisper on predicting lyrics.
 """
+
 import warnings
 
 import evaluate
@@ -8,7 +9,12 @@ import numpy as np
 import torch
 from datasets import Dataset
-from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
+from transformers import (
+    WhisperProcessor,
+    WhisperForConditionalGeneration,
+    Seq2SeqTrainingArguments,
+    Seq2SeqTrainer,
+)
 from transformers.models.whisper.english_normalizer import BasicTextNormalizer
 
 from training.collator import DataCollatorSpeechSeq2SeqWithPadding
 
@@ -22,17 +28,19 @@ class Trainer:
     """
     A class that represents the trainer for the whisper model.
     """
-    def __init__(self, dataset=None, model_name="openai/whisper-small", ):
+
+    def __init__(
+        self,
+        dataset=None,
+        model_name="openai/whisper-small",
+    ):
         """
         The constructor for the Trainer class.
         The dataset is optional and can be added later with the method process_dataset.
         The dataset should be formated and already mapped to the columns "audio" and "lyrics" and ready for training.
         :param dataset: The dataset to train the model on.
         """
-        self.processor = WhisperProcessor.from_pretrained(
-            model_name,
-            task="transcribe"
-        )
+        self.processor = WhisperProcessor.from_pretrained(model_name, task="transcribe")
         self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
         self.dataset = dataset
         self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
@@ -48,7 +56,9 @@ def prepare_tokenizer(self) -> None:
             special_tokens_to_add.append(f"[VERSE {i}]")
         special_tokens_to_add.append("[CHORUS]")
         special_tokens_to_add.append("[BRIDGE]")
-        self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
+        self.processor.tokenizer.add_special_tokens(
+            {"additional_special_tokens": special_tokens_to_add}
+        )
         self.model.resize_token_embeddings(len(self.processor.tokenizer))
 
@@ -56,6 +66,7 @@ def process_dataset(self, dataset, chunk_id) -> Dataset:
         """
         A method that processes the dataset.
         :return: None
         """
+
         def prepare_dataset(example):
             target_sr = self.processor.feature_extractor.sampling_rate
             with warnings.catch_warnings():
@@ -110,7 +121,9 @@ def compute_metrics(self, pred):
         label_str_norm = [NORMALIZER(label) for label in label_str]
         # filtering step to only evaluate the samples that correspond to non-zero references:
         pred_str_norm = [
-            pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
+            pred_str_norm[i]
+            for i in range(len(pred_str_norm))
+            if len(label_str_norm[i]) > 0
         ]
         label_str_norm = [
             label_str_norm[i]
@@ -159,3 +172,12 @@ def train(self):
             tokenizer=self.processor,
         )
         return trainer.train()
+
+    def save_model(self, path: str) -> None:
+        """
+        A method that saves the model.
+        :param path: The path to save the model.
+        :return: None
+        """
+
+        self.model.save_pretrained(path)
diff --git a/training/utils.py b/training/utils.py
index 10d96fa..00392fa 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -18,7 +18,6 @@ def gather_dataset(path: str) -> Dataset:
     """
 
     def gen():
-        i = 0 # use to regenerate the dataset
        audios = glob.glob(path + "/audio/*")
        lyrics = glob.glob(path + "/lyrics/*.txt")
        for audio, lyric in zip(audios, lyrics):