[update] add common voice fr
Jourdelune committed Jun 9, 2024
1 parent 6f59154 commit 1f71ad5
Showing 3 changed files with 81 additions and 16 deletions.
57 changes: 57 additions & 0 deletions common_voice.py
@@ -0,0 +1,57 @@
from datasets import Audio, DatasetDict, load_dataset

from training.train import Trainer

common_voice = DatasetDict()

common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    "fr",
    split="validation[0:5000]",
    use_auth_token=True,
)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    "fr",
    split="test[0:100]",
    use_auth_token=True,
)

common_voice = common_voice.remove_columns(
    [
        "accent",
        "age",
        "client_id",
        "down_votes",
        "gender",
        "locale",
        "path",
        "segment",
        "up_votes",
    ]
)

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

trainer = Trainer()


def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = trainer.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]

    # encode target text to label ids
    batch["labels"] = trainer.tokenizer(batch["sentence"]).input_ids
    return batch


common_voice = common_voice.map(
    prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1
)

trainer.train(common_voice)
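
A note on the script above: the cast to Audio(sampling_rate=16000) resamples lazily when each example is decoded, and the Whisper feature extractor then pads or truncates every clip to a fixed 30-second window. A quick sanity check of the prepared dataset could look like the sketch below (an illustration only, assuming the script has run as written with the default openai/whisper-tiny checkpoint; the variable names reuse the ones defined above):

import numpy as np

# inspect the first prepared training example
sample = common_voice["train"][0]

# whisper-tiny log-Mel features should come out as an (80, 3000) array:
# 80 mel bins over a 30 s padded window
print(np.array(sample["input_features"]).shape)

# round-trip the label ids back to the reference transcript
print(trainer.tokenizer.decode(sample["labels"], skip_special_tokens=True))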
14 changes: 11 additions & 3 deletions train2.py
@@ -1,6 +1,6 @@
import librosa
import numpy as np
-from datasets import Audio, DatasetDict, load_dataset
+from datasets import Audio, DatasetDict, load_from_disk

from training import utils
from training.train import Trainer
@@ -12,6 +1,7 @@

is_prepared = False


if not is_prepared:
    target_sr = trainer.processor.feature_extractor.sampling_rate

@@ -28,14 +29,21 @@ def prepare_dataset(batch):
        batch["labels"] = trainer.tokenizer(batch["lyrics"]).input_ids
        return batch

-    dataset = dataset.map(prepare_dataset, num_proc=1)
+    dataset = dataset.map(
+        prepare_dataset, remove_columns=dataset.column_names, num_proc=1
+    )

    # filter out samples with empty labels
    dataset = dataset.filter(lambda x: len(x["labels"]) > 5)

    # save the processed dataset
    dataset.save_to_disk("dataset/test/")

else:
    # load the processed dataset
-    dataset = load_dataset("dataset/test/")
+    dataset = load_from_disk("dataset/test/")

print(dataset)

dataset = dataset.train_test_split(test_size=0.05)
trainer.train(dataset)
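
One detail behind the load_from_disk change: a dataset written with Dataset.save_to_disk (as in the if branch above) is reloaded with datasets.load_from_disk; load_dataset does not accept that on-disk layout. A minimal round trip, with a throwaway path and columns made up purely for illustration:

from datasets import Dataset, load_from_disk

# toy stand-in with the same kind of columns produced by prepare_dataset
toy = Dataset.from_dict({"input_features": [[0.0] * 4], "labels": [[1, 2, 3]]})

toy.save_to_disk("dataset/toy/")             # writes Arrow shards plus dataset metadata
reloaded = load_from_disk("dataset/toy/")    # the counterpart of save_to_disk
print(reloaded)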
26 changes: 13 additions & 13 deletions training/train.py
@@ -10,10 +10,13 @@
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    logging,
)

from training.collator import DataCollatorSpeechSeq2SeqWithPadding

logging.set_verbosity_warning()


class Trainer:
    """
@@ -23,30 +26,23 @@ class Trainer:
    def __init__(
        self,
        model_name="openai/whisper-tiny",
-        language="hindi",
        task="transcribe",
        output_dir="./whisper-finetuned",
    ):
        """Function to initialize the Trainer class.
        Args:
            model_name (str, optional): _description_. Defaults to "openai/whisper-tiny".
-            language (str, optional): _description_. Defaults to "hindi".
            task (str, optional): _description_. Defaults to "transcribe".
            output_dir (str, optional): _description_. Defaults to "./whisper-finetuned".
        """

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
-        self.tokenizer = WhisperTokenizer.from_pretrained(
-            model_name, language=language, task=task
-        )
+        self.tokenizer = WhisperTokenizer.from_pretrained(model_name, task=task)

-        self.processor = WhisperProcessor.from_pretrained(
-            model_name, language=language, task=task
-        )
+        self.processor = WhisperProcessor.from_pretrained(model_name, task=task)

        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
-        self.model.generation_config.language = language
        self.model.generation_config.task = task

        self.model.generation_config.forced_decoder_ids = None
@@ -70,8 +66,11 @@ def _compute_metrics(self, pred):
        pred_str = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        print(pred_str[0])
        print(label_str[0])

        wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
        print(f"WER: {wer}")

        return {"wer": wer}

    def train(self, dataset):
@@ -89,18 +88,19 @@ def train(self, dataset):
            max_steps=4000,
            gradient_checkpointing=True,
            fp16=True,
-            evaluation_strategy="steps",
+            eval_strategy="steps",
            per_device_eval_batch_size=8,
            predict_with_generate=True,
            generation_max_length=225,
            save_steps=80,
-            eval_steps=40,
+            eval_steps=80,
            logging_steps=25,
            report_to=["tensorboard"],
            load_best_model_at_end=True,
            metric_for_best_model="wer",
            greater_is_better=False,
-            push_to_hub=True,
+            push_to_hub=False,
            gradient_checkpointing_kwargs={"use_reentrant": False},
        )

        trainer = Seq2SeqTrainer(
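The diff is collapsed at the Seq2SeqTrainer call, so its arguments are not visible in this view. For orientation only, a typical way to wire up the pieces defined above (the Seq2SeqTrainingArguments, the model, the collator imported from training.collator, and _compute_metrics) is sketched here; the variable name training_args and the constructor signature of DataCollatorSpeechSeq2SeqWithPadding are assumptions, not the repository's actual code:

        # hypothetical continuation; the real arguments are hidden in the collapsed diff
        trainer = Seq2SeqTrainer(
            args=training_args,
            model=self.model,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=self.processor),
            compute_metrics=self._compute_metrics,
            tokenizer=self.processor.feature_extractor,
        )
        trainer.train()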
