Skip to content

Commit

Permalink
update data processing
Browse files Browse the repository at this point in the history
  • Loading branch information
ostix360 committed May 29, 2024
1 parent 0f8aa5b commit 3aea584
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 54 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,12 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

dataset/audio/
dataset/lyrics/
dataset/data.json
train/
formated_dataset/

test.py
2 changes: 1 addition & 1 deletion download_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
parser = argparse.ArgumentParser(
description="Download images from Sonauto dataset",
)
parser.add_argument("--num_images", type=int, default=10)
parser.add_argument("--num_images", type=int, default=1000)
parser.add_argument("--clean", type=bool, default=True)

args = parser.parse_args()
Expand Down
16 changes: 12 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from datasets import DatasetDict

from training.train import Trainer

import training.utils as utils

LOAD_DATASET = True

dataset = utils.gather_dataset("./dataset")
dataset = dataset.train_test_split(test_size=0.1)

if LOAD_DATASET:
dataset = utils.gather_dataset("./dataset")
dataset = dataset.train_test_split(test_size=0.1)
else:
dataset = DatasetDict.load_from_disk("./formated_dataset")
trainer = Trainer(dataset)
trainer.process_dataset()
if LOAD_DATASET:
dataset = trainer.process_dataset(dataset)
dataset.save_to_disk("./formated_dataset")

trainer.train()
11 changes: 10 additions & 1 deletion training/collator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
This file contains the data collator for the Speech2Text model.
"""

from dataclasses import dataclass
from typing import Any, Dict, Union, List

Expand All @@ -11,6 +15,11 @@ class DataCollatorSpeechSeq2SeqWithPadding:
def __call__(
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
"""
This method pads the input features and the labels to the maximum length in the batch and return it.
:param features: The features to pad.
:return: The padded features.
"""
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [
Expand All @@ -35,4 +44,4 @@ def __call__(

batch["labels"] = labels

return batch
return batch
47 changes: 33 additions & 14 deletions training/train.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""
This module contains the Trainer class which is responsible for training whisper on predicting lyrics.
"""
import warnings

import librosa
import numpy as np
import torch
from datasets import DatasetDict, Audio
from datasets import Audio, Dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer

from training import utils
from training.collator import DataCollatorSpeechSeq2SeqWithPadding
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import evaluate
Expand All @@ -19,9 +23,11 @@ class Trainer:
"""
A class that represents the trainer for the whisper model.
"""
def __init__(self, dataset: DatasetDict, model_name="openai/whisper-small", ):
def __init__(self, dataset=None, model_name="openai/whisper-small", ):
"""
The constructor for the Trainer class.
The dataset is optional and can be added later with the method process_dataset.
The dataset should be formated and already mapped to the columns "audio" and "lyrics" and ready for training.
:param dataset: The dataset to train the model on.
"""
self.processor = WhisperProcessor.from_pretrained(
Expand All @@ -30,9 +36,6 @@ def __init__(self, dataset: DatasetDict, model_name="openai/whisper-small", ):
)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.dataset = dataset
self.dataset = self.dataset.select_columns(["audio", "lyrics"])
sampling_rate = self.processor.feature_extractor.sampling_rate
self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
self.prepare_tokenizer()

Expand All @@ -49,31 +52,46 @@ def prepare_tokenizer(self) -> None:
self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
self.model.resize_token_embeddings(len(self.processor.tokenizer))

def process_dataset(self) -> None:
def process_dataset(self, dataset) -> Dataset:
"""
A method that processes the dataset.
:return: None
"""
def prepare_dataset(example):
audio = example["audio"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
audio, sr = librosa.load(example["audio"], sr=None)
audio = librosa.resample(
np.asarray(audio),
orig_sr=sr,
target_sr=self.processor.feature_extractor.sampling_rate
)

example = self.processor(
audio=audio["array"],
sampling_rate=audio["sampling_rate"],
audio=audio,
sampling_rate=self.processor.feature_extractor.sampling_rate,
text=example["lyrics"],
)

# compute input length of audio sample in seconds
example["input_length"] = len(audio["array"]) / audio["sampling_rate"]
example["input_length"] = len(audio) / sr

return example

self.dataset = self.dataset.map(
self.dataset = dataset.map(
prepare_dataset,
remove_columns=self.dataset.column_names["train"],
remove_columns=dataset.column_names["train"],
num_proc=1,
)
return self.dataset

def compute_metrics(self, pred):
"""
A method that computes the metrics.
:param pred: The predictions of the model.
:return: The metrics.
"""
pred_ids = pred.predictions
label_ids = pred.label_ids

Expand Down Expand Up @@ -111,8 +129,9 @@ def train(self):
"""
training_args = Seq2SeqTrainingArguments(
output_dir="./train",
per_device_train_batch_size=4,
per_device_train_batch_size=8,
per_device_eval_batch_size=4,
num_train_epochs=1,
learning_rate=1e-5,
lr_scheduler_type="linear",
warmup_steps=50,
Expand All @@ -124,7 +143,7 @@ def train(self):
evaluation_strategy="epoch",
optim="adamw_8bit",
predict_with_generate=True,
generation_max_length=225,
generation_max_length=512,
logging_steps=25,
metric_for_best_model="wer",
greater_is_better=False,
Expand Down
51 changes: 18 additions & 33 deletions training/utils.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,25 @@
"""
This module contains the utilities for the training module.
"""
import glob
import warnings
from dataclasses import dataclass
from io import BytesIO
from typing import Optional, Dict, Union
from typing import Tuple

import librosa
import soundfile as sf
from datasets import Dataset, Audio
from tqdm import tqdm


def reformat_audio():
"""Function that reformats the audio files."""
print("Reformatting audio files...")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
bar = tqdm(total=len(glob.glob("./dataset/audio/*")))
for audio in glob.glob("./dataset/audio/*"):
try:
y, sr = sf.read(audio)
if audio.split(".")[-1] != "ogg" and sr != 16000:
y = librosa.resample(y, orig_sr=sr, target_sr=16000)
audio = audio.replace(extension, "ogg")
sf.write(audio, y, sr, format='ogg', subtype='vorbis')
except Exception:
extension = audio.split(".")[-1]
y, sr = librosa.load(audio, sr=None)
y = librosa.resample(y, orig_sr=sr, target_sr=16000)
audio = audio.replace(extension, "ogg")
sf.write(audio, y, sr, format='ogg', subtype='vorbis')
bar.update(1)
import torchaudio
from datasets import Dataset
from numpy import ndarray


def reformat_audio(path: str) -> Tuple[ndarray, int]:
"""
Function that reformats the audio files.
:param path: The path to the audio file.
:return: The audio array and the sampling rate.
"""

return


def gather_dataset(path: str) -> Dataset:
Expand All @@ -46,13 +32,12 @@ def gather_dataset(path: str) -> Dataset:
Dataset: The dataset.
"""
def gen():
i = 0
audio = glob.glob(path + "/audio/*.ogg")
i = 1 # use to
audio = glob.glob(path + "/audio/*")
lyrics = glob.glob(path + "/lyrics/*.txt")
for i in range(len(lyrics)):
yield {
"audio": audio[i],
"lyrics": open(lyrics[i], "r").read(),
}
# reformat_audio()
return Dataset.from_generator(gen).cast_column("audio", Audio())
return Dataset.from_generator(gen)

0 comments on commit 3aea584

Please sign in to comment.