-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add training script; data processing still needs to be optimized
- Loading branch information
Showing
6 changed files
with
263 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,10 @@ | ||
requests | ||
orjson | ||
jiwer | ||
transformers | ||
torch | ||
torchaudio | ||
datasets | ||
accelerate | ||
bitsandbytes | ||
evaluate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Entry point: fine-tune Whisper on the local lyrics dataset."""
import training.utils as utils
from training.train import Trainer


def main() -> None:
    """Gather the dataset, split it 90/10, and run training."""
    dataset = utils.gather_dataset("./dataset")
    # Hold out 10% of the songs for evaluation.
    dataset = dataset.train_test_split(test_size=0.1)

    trainer = Trainer(dataset)
    trainer.process_dataset()
    trainer.train()


if __name__ == "__main__":
    # Guard so importing this module does not kick off a training run.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, Union, List | ||
|
||
import torch | ||
|
||
|
||
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate variable-length speech features and token labels into one batch.

    Audio inputs and text labels require different padding strategies, so the
    processor's feature extractor pads the former while its tokenizer pads the
    latter; padded label positions are replaced with -100 so the loss ignores
    them.
    """

    # Expected to expose .feature_extractor and .tokenizer (e.g. WhisperProcessor).
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Pad the audio side; each example stores its features at index 0.
        audio_side = []
        for feature in features:
            audio_side.append({"input_features": feature["input_features"][0]})
        batch = self.processor.feature_extractor.pad(audio_side, return_tensors="pt")

        # Pad the tokenized label sequences to a common length.
        text_side = [{"input_ids": feature["labels"]} for feature in features]
        padded_labels = self.processor.tokenizer.pad(text_side, return_tensors="pt")

        # Positions the attention mask marks as padding must not contribute loss.
        label_ids = padded_labels["input_ids"].masked_fill(
            padded_labels.attention_mask.ne(1), -100
        )

        # If every row begins with BOS (appended during tokenization), strip it
        # here — the model re-appends it later anyway.
        bos = self.processor.tokenizer.bos_token_id
        if bool((label_ids[:, 0] == bos).all().cpu().item()):
            label_ids = label_ids[:, 1:]

        batch["labels"] = label_ids

        return batch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
""" | ||
This module contains the Trainer class which is responsible for training whisper on predicting lyrics. | ||
""" | ||
|
||
import torch | ||
from datasets import DatasetDict, Audio | ||
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer | ||
|
||
from training.collator import DataCollatorSpeechSeq2SeqWithPadding | ||
from transformers.models.whisper.english_normalizer import BasicTextNormalizer | ||
import evaluate | ||
|
||
METRIC = evaluate.load("wer") | ||
|
||
NORMALIZER = BasicTextNormalizer() | ||
|
||
|
||
class Trainer:
    """
    A class that represents the trainer for the whisper model.

    Bundles a Whisper processor/model pair with the dataset and exposes
    process_dataset() / train() to run the fine-tuning pipeline.
    """
    def __init__(self, dataset: DatasetDict, model_name="openai/whisper-small", ):
        """
        The constructor for the Trainer class.

        :param dataset: The dataset to train the model on. Must provide
            "audio" and "lyrics" columns (all other columns are dropped).
        :param model_name: Hugging Face model id of the Whisper checkpoint
            to fine-tune.
        """
        self.processor = WhisperProcessor.from_pretrained(
            model_name,
            task="transcribe"
        )
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.dataset = dataset
        # Keep only the columns the pipeline actually consumes.
        self.dataset = self.dataset.select_columns(["audio", "lyrics"])
        # Re-decode audio at the sampling rate the feature extractor expects.
        sampling_rate = self.processor.feature_extractor.sampling_rate
        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
        self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
        self.prepare_tokenizer()

    def prepare_tokenizer(self) -> None:
        """
        A method that adds special tokens i.e. tags to the tokenizer.

        Registers [VERSE 1]..[VERSE 4], [CHORUS] and [BRIDGE] as additional
        special tokens and resizes the model's embedding matrix so the new
        token ids have embeddings.
        :return: None
        """
        special_tokens_to_add = []
        for i in range(1, 5):
            special_tokens_to_add.append(f"[VERSE {i}]")
        special_tokens_to_add.append("[CHORUS]")
        special_tokens_to_add.append("[BRIDGE]")
        self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
        # The embedding table must grow to cover the newly added token ids.
        self.model.resize_token_embeddings(len(self.processor.tokenizer))

    def process_dataset(self) -> None:
        """
        A method that processes the dataset.

        Maps every example through the Whisper processor, producing model
        inputs from the audio array and token labels from the lyrics, and
        records the clip length in seconds as "input_length".
        :return: None
        """
        def prepare_dataset(example):
            # Feed the raw waveform and the lyrics text through the processor.
            audio = example["audio"]

            example = self.processor(
                audio=audio["array"],
                sampling_rate=audio["sampling_rate"],
                text=example["lyrics"],
            )

            # compute input length of audio sample in seconds
            example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

            return example

        self.dataset = self.dataset.map(
            prepare_dataset,
            # Drop the raw columns; only the processor outputs are kept.
            remove_columns=self.dataset.column_names["train"],
        )

    def compute_metrics(self, pred):
        """
        Compute orthographic and normalised word error rates (as percentages)
        for an evaluation batch.

        NOTE(review): currently unused — the ``compute_metrics`` argument is
        commented out in :meth:`train`.

        :param pred: An EvalPrediction-like object with ``predictions`` and
            ``label_ids`` token-id arrays.
        :return: dict with "wer_ortho" (raw WER) and "wer" (normalised WER).
        """
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # replace -100 with the pad_token_id
        label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id

        # we do not want to group tokens when computing the metrics
        pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)

        # compute orthographic wer
        wer_ortho = 100 * METRIC.compute(predictions=pred_str, references=label_str)

        # compute normalised WER
        pred_str_norm = [NORMALIZER(pred) for pred in pred_str]
        label_str_norm = [NORMALIZER(label) for label in label_str]
        # filtering step to only evaluate the samples that correspond to non-zero references:
        pred_str_norm = [
            pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
        ]
        label_str_norm = [
            label_str_norm[i]
            for i in range(len(label_str_norm))
            if len(label_str_norm[i]) > 0
        ]

        wer = 100 * METRIC.compute(predictions=pred_str_norm, references=label_str_norm)

        return {"wer_ortho": wer_ortho, "wer": wer}

    def train(self):
        """
        A method that trains the model.

        Builds Seq2SeqTrainingArguments (mixed precision chosen from the
        GPU's bf16 support, 8-bit AdamW, epoch-level evaluation) and runs
        Hugging Face's Seq2SeqTrainer on the "train"/"test" splits.
        :return: the result of ``Seq2SeqTrainer.train()``.
        """
        training_args = Seq2SeqTrainingArguments(
            output_dir="./train",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            learning_rate=1e-5,
            lr_scheduler_type="linear",
            warmup_steps=50,
            gradient_checkpointing=False,
            # Prefer bf16 when the hardware supports it; otherwise fall back to fp16.
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            bf16_full_eval=torch.cuda.is_bf16_supported(),
            fp16_full_eval=not torch.cuda.is_bf16_supported(),
            evaluation_strategy="epoch",
            optim="adamw_8bit",
            predict_with_generate=True,
            generation_max_length=225,
            logging_steps=25,
            # NOTE(review): metric_for_best_model="wer" has no effect while
            # compute_metrics is disabled below — confirm intent.
            metric_for_best_model="wer",
            greater_is_better=False,
            push_to_hub=True,
        )

        trainer = Seq2SeqTrainer(
            args=training_args,
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            data_collator=self.data_collator,
            # compute_metrics=self.compute_metrics,
            tokenizer=self.processor,
        )
        return trainer.train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
""" | ||
import glob | ||
import warnings | ||
from dataclasses import dataclass | ||
from io import BytesIO | ||
from typing import Optional, Dict, Union | ||
|
||
import librosa | ||
import soundfile as sf | ||
from datasets import Dataset, Audio | ||
from tqdm import tqdm | ||
|
||
|
||
def reformat_audio():
    """Convert every file in ./dataset/audio to 16 kHz ogg/vorbis in place.

    Files soundfile can read are resampled only when they are not already
    ogg at 16 kHz; anything soundfile cannot decode (e.g. mp3) is loaded
    via librosa and always converted.
    """
    print("Reformatting audio files...")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        files = glob.glob("./dataset/audio/*")
        bar = tqdm(total=len(files))
        for path in files:
            # Compute the extension up front: the original assigned it only in
            # the except-branch, so the try-branch raised NameError on the
            # first file that needed conversion.
            extension = path.split(".")[-1]
            try:
                y, sr = sf.read(path)
                if extension != "ogg" and sr != 16000:
                    y = librosa.resample(y, orig_sr=sr, target_sr=16000)
                    # Write at the new rate — the original wrote the resampled
                    # samples with the old sr, changing playback speed.
                    sr = 16000
                    path = path.replace(extension, "ogg")
                    sf.write(path, y, sr, format='ogg', subtype='vorbis')
            except Exception:
                # soundfile could not decode this format; fall back to librosa.
                y, sr = librosa.load(path, sr=None)
                y = librosa.resample(y, orig_sr=sr, target_sr=16000)
                sr = 16000
                path = path.replace(extension, "ogg")
                sf.write(path, y, sr, format='ogg', subtype='vorbis')
            bar.update(1)
|
||
|
||
def gather_dataset(path: str) -> Dataset:
    """Function that gathers the dataset.

    Expects ``<path>/audio/*.ogg`` and ``<path>/lyrics/*.txt``. Both listings
    are sorted so the i-th audio file pairs with the i-th lyrics file —
    ``glob`` alone returns files in arbitrary order, which could previously
    pair an audio file with the wrong lyrics.

    Args:
        path (str): The path to the dataset.
    Returns:
        Dataset: The dataset with the "audio" column decoded by ``Audio()``.
    """
    def gen():
        audio_files = sorted(glob.glob(path + "/audio/*.ogg"))
        lyric_files = sorted(glob.glob(path + "/lyrics/*.txt"))
        for audio_file, lyric_file in zip(audio_files, lyric_files):
            # Context manager closes the handle (the original leaked it).
            with open(lyric_file, "r", encoding="utf-8") as fh:
                yield {
                    "audio": audio_file,
                    "lyrics": fh.read(),
                }
    # reformat_audio()  # one-off preprocessing; enable when raw files change
    return Dataset.from_generator(gen).cast_column("audio", Audio())