diff --git a/dataset/process.py b/dataset/process.py index 0b4b070..7f017ad 100644 --- a/dataset/process.py +++ b/dataset/process.py @@ -34,7 +34,7 @@ def __init__( self.export_path = export_path self.sample_rate = sample_rate - if clean and self.export_path and os.path.exists(self.export_path): + if clean and self.export_path: self.remove_export_folder() self.create_export_folder() diff --git a/train.py b/train.py index 7ed6691..4c4b93f 100644 --- a/train.py +++ b/train.py @@ -1,19 +1,10 @@ -from datasets import DatasetDict - -from training.train import Trainer - from training import utils +from training.train import Trainer -LOAD_DATASET = True +dataset = utils.gather_dataset("./dataset/export") +dataset = dataset.train_test_split(test_size=0.1) -if LOAD_DATASET: - dataset = utils.gather_dataset("./dataset") - dataset = dataset.train_test_split(test_size=0.1) -else: - dataset = DatasetDict.load_from_disk("./formated_dataset") trainer = Trainer(dataset) -if LOAD_DATASET: - dataset = trainer.process_dataset(dataset) - dataset.save_to_disk("./formated_dataset") +dataset = trainer.process_dataset(dataset) trainer.train() diff --git a/training/train.py b/training/train.py index 2e5f6b9..fd6cc40 100644 --- a/training/train.py +++ b/training/train.py @@ -1,6 +1,7 @@ """ This module contains the Trainer class which is responsible for training whisper on predicting lyrics. """ + import warnings import evaluate @@ -8,7 +9,12 @@ import numpy as np import torch from datasets import Dataset -from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer +from transformers import ( + WhisperProcessor, + WhisperForConditionalGeneration, + Seq2SeqTrainingArguments, + Seq2SeqTrainer, +) from transformers.models.whisper.english_normalizer import BasicTextNormalizer from training.collator import DataCollatorSpeechSeq2SeqWithPadding @@ -22,17 +28,19 @@ class Trainer: """ A class that represents the trainer for the whisper model. """ - def __init__(self, dataset=None, model_name="openai/whisper-small", ): + + def __init__( + self, + dataset=None, + model_name="openai/whisper-small", + ): """ The constructor for the Trainer class. The dataset is optional and can be added later with the method process_dataset. The dataset should be formated and already mapped to the columns "audio" and "lyrics" and ready for training. :param dataset: The dataset to train the model on. """ - self.processor = WhisperProcessor.from_pretrained( - model_name, - task="transcribe" - ) + self.processor = WhisperProcessor.from_pretrained(model_name, task="transcribe") self.model = WhisperForConditionalGeneration.from_pretrained(model_name) self.dataset = dataset self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor) @@ -48,7 +56,9 @@ def prepare_tokenizer(self) -> None: special_tokens_to_add.append(f"[VERSE {i}]") special_tokens_to_add.append("[CHORUS]") special_tokens_to_add.append("[BRIDGE]") - self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add}) + self.processor.tokenizer.add_special_tokens( + {"additional_special_tokens": special_tokens_to_add} + ) self.model.resize_token_embeddings(len(self.processor.tokenizer)) def process_dataset(self, dataset) -> Dataset: @@ -56,6 +66,7 @@ def process_dataset(self, dataset) -> Dataset: A method that processes the dataset. :return: None """ + def prepare_dataset(example): target_sr = self.processor.feature_extractor.sampling_rate with warnings.catch_warnings(): @@ -110,7 +121,9 @@ def compute_metrics(self, pred): label_str_norm = [NORMALIZER(label) for label in label_str] # filtering step to only evaluate the samples that correspond to non-zero references: pred_str_norm = [ - pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0 + pred_str_norm[i] + for i in range(len(pred_str_norm)) + if len(label_str_norm[i]) > 0 ] label_str_norm = [ label_str_norm[i] diff --git a/training/utils.py b/training/utils.py index 10d96fa..00392fa 100644 --- a/training/utils.py +++ b/training/utils.py @@ -18,7 +18,6 @@ def gather_dataset(path: str) -> Dataset: """ def gen(): - i = 0 # use to regenerate the dataset audios = glob.glob(path + "/audio/*") lyrics = glob.glob(path + "/lyrics/*.txt") for audio, lyric in zip(audios, lyrics):