
Commit

add training script; data processing still needs optimization
ostix360 committed May 29, 2024
1 parent a80cb55 commit 0f8aa5b
Showing 6 changed files with 263 additions and 1 deletion.
4 changes: 3 additions & 1 deletion dataset/build.py
@@ -154,7 +154,9 @@ def _construct_ds(self, audio_data: dict) -> None:
         with self.session.get(
             "https://cdn.sonauto.ai/generations/" + value["audio_url"]
         ) as resp:
-            with open(f"./dataset/audio/{nbm_file}.ogg", "wb") as file:
+            extension = value["audio_url"].split(".")[-1]
+            with open(f"./dataset/audio/{nbm_file}.{extension}", "wb") as file:
                 file.write(resp.content)

         print(f"Downloaded {nbm_file} songs.", end="\r")
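If audio_url ever carries a query string, the bare split(".") above would fold it into the extension. A hedged hardening (hypothetical, not part of this commit):

from urllib.parse import urlparse

# take the extension from the URL path only, ignoring any ?query suffix
extension = urlparse(value["audio_url"]).path.split(".")[-1]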
8 changes: 8 additions & 0 deletions requirements.txt
@@ -1,2 +1,10 @@
requests
orjson
jiwer
transformers
torch
torchaudio
datasets
accelerate
bitsandbytes
evaluate
11 changes: 11 additions & 0 deletions train.py
@@ -0,0 +1,11 @@
from training.train import Trainer

import training.utils as utils


dataset = utils.gather_dataset("./dataset")
dataset = dataset.train_test_split(test_size=0.1)

trainer = Trainer(dataset)
trainer.process_dataset()
trainer.train()
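Since gather_dataset pairs audio and lyrics files by position, a small sanity check before training can catch unpaired files early. This is a hypothetical addition, not part of the commit:

import glob
import os

# basenames without extension for both halves of each pair
audio = {os.path.splitext(os.path.basename(p))[0] for p in glob.glob("./dataset/audio/*.ogg")}
lyrics = {os.path.splitext(os.path.basename(p))[0] for p in glob.glob("./dataset/lyrics/*.txt")}
unpaired = audio ^ lyrics  # symmetric difference: files missing a counterpart
if unpaired:
    print(f"{len(unpaired)} unpaired files, e.g. {sorted(unpaired)[:5]}")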
38 changes: 38 additions & 0 deletions training/collator.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from typing import Any, Dict, Union, List

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any

def __call__(
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [
{"input_features": feature["input_features"][0]} for feature in features
]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(
labels_batch.attention_mask.ne(1), -100
)

        # if a bos token was prepended in the previous tokenization step,
        # cut it here since it is appended again later anyway
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]

batch["labels"] = labels

return batch
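For reference, a minimal usage sketch of the collator (not part of the commit; it assumes the openai/whisper-small checkpoint is downloadable and feeds two silent clips):

import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", task="transcribe")
collate = DataCollatorSpeechSeq2SeqWithPadding(processor)

# two dummy examples shaped like the output of processor(audio=..., text=...)
features = [
    processor(audio=np.zeros(16000), sampling_rate=16000, text="hello"),
    processor(audio=np.zeros(32000), sampling_rate=16000, text="hello world"),
]
batch = collate(features)
print(batch["input_features"].shape)  # (2, 80, 3000) log-mel spectrograms
print(batch["labels"].shape)          # right-padded label ids, padding masked to -100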
145 changes: 145 additions & 0 deletions training/train.py
@@ -0,0 +1,145 @@
"""
This module contains the Trainer class, which is responsible for fine-tuning Whisper to predict lyrics.
"""

import torch
from datasets import DatasetDict, Audio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer

from training.collator import DataCollatorSpeechSeq2SeqWithPadding
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import evaluate

METRIC = evaluate.load("wer")

NORMALIZER = BasicTextNormalizer()


class Trainer:
"""
A class that represents the trainer for the whisper model.
"""
def __init__(self, dataset: DatasetDict, model_name="openai/whisper-small", ):
"""
The constructor for the Trainer class.
:param dataset: The dataset to train the model on.
"""
self.processor = WhisperProcessor.from_pretrained(
model_name,
task="transcribe"
)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.dataset = dataset
self.dataset = self.dataset.select_columns(["audio", "lyrics"])
sampling_rate = self.processor.feature_extractor.sampling_rate
self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
self.prepare_tokenizer()

def prepare_tokenizer(self) -> None:
"""
        A method that adds the song-structure tags (e.g. [VERSE 1], [CHORUS]) as special tokens to the tokenizer.
:return: None
"""
special_tokens_to_add = []
for i in range(1, 5):
special_tokens_to_add.append(f"[VERSE {i}]")
special_tokens_to_add.append("[CHORUS]")
special_tokens_to_add.append("[BRIDGE]")
self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
self.model.resize_token_embeddings(len(self.processor.tokenizer))

def process_dataset(self) -> None:
"""
A method that processes the dataset.
:return: None
"""
def prepare_dataset(example):
audio = example["audio"]

example = self.processor(
audio=audio["array"],
sampling_rate=audio["sampling_rate"],
text=example["lyrics"],
)

# compute input length of audio sample in seconds
example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

return example

self.dataset = self.dataset.map(
prepare_dataset,
remove_columns=self.dataset.column_names["train"],
)

def compute_metrics(self, pred):
pred_ids = pred.predictions
label_ids = pred.label_ids

# replace -100 with the pad_token_id
label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id

# we do not want to group tokens when computing the metrics
pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)

# compute orthographic wer
wer_ortho = 100 * METRIC.compute(predictions=pred_str, references=label_str)

# compute normalised WER
pred_str_norm = [NORMALIZER(pred) for pred in pred_str]
label_str_norm = [NORMALIZER(label) for label in label_str]
# filtering step to only evaluate the samples that correspond to non-zero references:
pred_str_norm = [
pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
]
label_str_norm = [
label_str_norm[i]
for i in range(len(label_str_norm))
if len(label_str_norm[i]) > 0
]

wer = 100 * METRIC.compute(predictions=pred_str_norm, references=label_str_norm)

return {"wer_ortho": wer_ortho, "wer": wer}

def train(self):
"""
A method that trains the model.
        :return: The TrainOutput returned by Seq2SeqTrainer.train().
"""
training_args = Seq2SeqTrainingArguments(
output_dir="./train",
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
learning_rate=1e-5,
lr_scheduler_type="linear",
warmup_steps=50,
gradient_checkpointing=False,
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
bf16_full_eval=torch.cuda.is_bf16_supported(),
fp16_full_eval=not torch.cuda.is_bf16_supported(),
evaluation_strategy="epoch",
optim="adamw_8bit",
predict_with_generate=True,
generation_max_length=225,
logging_steps=25,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)

trainer = Seq2SeqTrainer(
args=training_args,
model=self.model,
train_dataset=self.dataset["train"],
eval_dataset=self.dataset["test"],
data_collator=self.data_collator,
# compute_metrics=self.compute_metrics,
tokenizer=self.processor,
)
return trainer.train()

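The commit message flags the data processing as needing optimisation; one hedged option (untested sketch, not in the commit) is to parallelise the map call inside process_dataset:

# Hypothetical variant of the map call in process_dataset:
# datasets.Dataset.map can spread the preparation over several worker processes.
self.dataset = self.dataset.map(
    prepare_dataset,
    remove_columns=self.dataset.column_names["train"],
    num_proc=4,  # assumption: 4 CPU workers; tune to the machine
)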

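Because compute_metrics reports both an orthographic and a normalised WER, here is an illustrative contrast (toy strings, not results from this commit). Note that BasicTextNormalizer also strips the bracketed structure tags:

import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

metric = evaluate.load("wer")
normalizer = BasicTextNormalizer()

pred = ["[CHORUS] Hello, world!"]
ref = ["[CHORUS] hello world"]
print(metric.compute(predictions=pred, references=ref))  # non-zero: case and punctuation count as errors
print(metric.compute(predictions=[normalizer(p) for p in pred],
                     references=[normalizer(r) for r in ref]))  # 0.0 once both sides are normalised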
58 changes: 58 additions & 0 deletions training/utils.py
@@ -0,0 +1,58 @@
"""
"""
import glob
import warnings

import librosa
import soundfile as sf
from datasets import Dataset, Audio
from tqdm import tqdm


def reformat_audio():
    """Function that resamples every audio file to 16 kHz Ogg Vorbis."""
    print("Reformatting audio files...")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        bar = tqdm(total=len(glob.glob("./dataset/audio/*")))
        for audio in glob.glob("./dataset/audio/*"):
            extension = audio.split(".")[-1]
            try:
                y, sr = sf.read(audio)
            except Exception:
                # soundfile cannot decode some formats; fall back to librosa
                y, sr = librosa.load(audio, sr=None)
            if extension != "ogg" or sr != 16000:
                y = librosa.resample(y, orig_sr=sr, target_sr=16000)
                audio = audio.replace(extension, "ogg")
                # write at the new 16 kHz rate, not the original sample rate
                sf.write(audio, y, 16000, format='ogg', subtype='vorbis')
            bar.update(1)


def gather_dataset(path: str) -> Dataset:
"""Function that gathers the dataset.
Args:
path (str): The path to the dataset.
Returns:
Dataset: The dataset.
"""
    def gen():
        # sort both globs so each audio file lines up with its lyrics file
        audio = sorted(glob.glob(path + "/audio/*.ogg"))
        lyrics = sorted(glob.glob(path + "/lyrics/*.txt"))
        for i in range(len(lyrics)):
            with open(lyrics[i], "r") as f:
                yield {
                    "audio": audio[i],
                    "lyrics": f.read(),
                }
# reformat_audio()
return Dataset.from_generator(gen).cast_column("audio", Audio())
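A quick sketch of what gather_dataset yields (assuming the ./dataset layout above):

from training.utils import gather_dataset

ds = gather_dataset("./dataset")
print(ds)                                # Dataset({features: ['audio', 'lyrics'], num_rows: ...})
sample = ds[0]                           # datasets.Audio decodes the file lazily on access
print(sample["audio"]["sampling_rate"], sample["lyrics"][:60])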
