Skip to content

Commit

Permalink
[update] change training script & opti
Browse files Browse the repository at this point in the history
  • Loading branch information
Jourdelune committed Jun 2, 2024
1 parent 234e0c4 commit b8288f9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 23 deletions.
2 changes: 1 addition & 1 deletion dataset/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.export_path = export_path
self.sample_rate = sample_rate

if clean and self.export_path and os.path.exists(self.export_path):
if clean and self.export_path:
self.remove_export_folder()

self.create_export_folder()
Expand Down
17 changes: 4 additions & 13 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from datasets import DatasetDict

from training.train import Trainer

from training import utils
from training.train import Trainer

LOAD_DATASET = True
dataset = utils.gather_dataset("./dataset/export")
dataset = dataset.train_test_split(test_size=0.1)

if LOAD_DATASET:
dataset = utils.gather_dataset("./dataset")
dataset = dataset.train_test_split(test_size=0.1)
else:
dataset = DatasetDict.load_from_disk("./formated_dataset")
trainer = Trainer(dataset)
if LOAD_DATASET:
dataset = trainer.process_dataset(dataset)
dataset.save_to_disk("./formated_dataset")
dataset = trainer.process_dataset(dataset)

trainer.train()
29 changes: 21 additions & 8 deletions training/train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""
This module contains the Trainer class which is responsible for training whisper on predicting lyrics.
"""

import warnings

import evaluate
import librosa
import numpy as np
import torch
from datasets import Dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

from training.collator import DataCollatorSpeechSeq2SeqWithPadding
Expand All @@ -22,17 +28,19 @@ class Trainer:
"""
A class that represents the trainer for the whisper model.
"""
def __init__(self, dataset=None, model_name="openai/whisper-small", ):

def __init__(
self,
dataset=None,
model_name="openai/whisper-small",
):
"""
The constructor for the Trainer class.
The dataset is optional and can be added later with the method process_dataset.
The dataset should be formated and already mapped to the columns "audio" and "lyrics" and ready for training.
:param dataset: The dataset to train the model on.
"""
self.processor = WhisperProcessor.from_pretrained(
model_name,
task="transcribe"
)
self.processor = WhisperProcessor.from_pretrained(model_name, task="transcribe")
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.dataset = dataset
self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
Expand All @@ -48,14 +56,17 @@ def prepare_tokenizer(self) -> None:
special_tokens_to_add.append(f"[VERSE {i}]")
special_tokens_to_add.append("[CHORUS]")
special_tokens_to_add.append("[BRIDGE]")
self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
self.processor.tokenizer.add_special_tokens(
{"additional_special_tokens": special_tokens_to_add}
)
self.model.resize_token_embeddings(len(self.processor.tokenizer))

def process_dataset(self, dataset) -> Dataset:
"""
A method that processes the dataset.
:return: None
"""

def prepare_dataset(example):
target_sr = self.processor.feature_extractor.sampling_rate
with warnings.catch_warnings():
Expand Down Expand Up @@ -110,7 +121,9 @@ def compute_metrics(self, pred):
label_str_norm = [NORMALIZER(label) for label in label_str]
# filtering step to only evaluate the samples that correspond to non-zero references:
pred_str_norm = [
pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
pred_str_norm[i]
for i in range(len(pred_str_norm))
if len(label_str_norm[i]) > 0
]
label_str_norm = [
label_str_norm[i]
Expand Down
1 change: 0 additions & 1 deletion training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def gather_dataset(path: str) -> Dataset:
"""

def gen():
i = 0 # use to regenerate the dataset
audios = glob.glob(path + "/audio/*")
lyrics = glob.glob(path + "/lyrics/*.txt")
for audio, lyric in zip(audios, lyrics):
Expand Down

0 comments on commit b8288f9

Please sign in to comment.