
Commit

[update] process original ds
Jourdelune committed Jun 8, 2024
1 parent 99fc968 commit 6f59154
Showing 2 changed files with 32 additions and 47 deletions.
train2.py (73 changes: 29 additions & 44 deletions)
@@ -1,56 +1,41 @@
-from datasets import load_dataset, DatasetDict
-from datasets import Audio
+import librosa
+import numpy as np
+from datasets import Audio, DatasetDict, load_dataset
 
+from training import utils
 from training.train import Trainer
 
-common_voice = DatasetDict()
-
-common_voice["train"] = load_dataset(
-    "mozilla-foundation/common_voice_11_0",
-    "hi",
-    split="train+validation",
-    use_auth_token=True,
-)
-common_voice["test"] = load_dataset(
-    "mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True
-)
-
-
-common_voice = common_voice.remove_columns(
-    [
-        "accent",
-        "age",
-        "client_id",
-        "down_votes",
-        "gender",
-        "locale",
-        "path",
-        "segment",
-        "up_votes",
-    ]
-)
-
-common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
+DS_PATH = "dataset/"
 
+dataset = utils.gather_dataset(DS_PATH)
+trainer = Trainer()
 
-def prepare_dataset(batch):
-    # load and resample audio data from 48 to 16kHz
-    audio = batch["audio"]
+is_prepared = False
 
-    # compute log-Mel input features from input audio array
-    batch["input_features"] = trainer.feature_extractor(
-        audio["array"], sampling_rate=audio["sampling_rate"]
-    ).input_features[0]
+if not is_prepared:
+    target_sr = trainer.processor.feature_extractor.sampling_rate
 
-    # encode target text to label ids
-    batch["labels"] = trainer.tokenizer(batch["sentence"]).input_ids
-    return batch
+    def prepare_dataset(batch):
+        # load and resample audio data from 48 to 16kHz
+        audio, _ = librosa.load(batch["audio"], sr=target_sr)
 
+        # compute log-Mel input features from input audio array
+        batch["input_features"] = trainer.feature_extractor(
+            audio, sampling_rate=target_sr
+        ).input_features[0]
 
-common_voice = common_voice.map(
-    prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1
-)
+        # encode target text to label ids
+        batch["labels"] = trainer.tokenizer(batch["lyrics"]).input_ids
+        return batch
 
-trainer.train(common_voice)
+    dataset = dataset.map(prepare_dataset, num_proc=1)
+
+    # save the processed dataset
+    dataset.save_to_disk("dataset/test/")
+
+else:
+    # load the processed dataset
+    dataset = load_dataset("dataset/test/")
+
+dataset = dataset.train_test_split(test_size=0.05)
+trainer.train(dataset)
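
The rewritten script drops the Common Voice pipeline in favor of a local dataset gathered by the repo's own utils.gather_dataset helper, decodes audio with librosa at the processor's sampling rate, and caches the processed features to disk. As a standalone illustration of the same per-example transform, here is a minimal sketch; the openai/whisper-small checkpoint and the audio_path/lyrics argument names are assumptions, since the real checkpoint and schema come from Trainer and gather_dataset, which are not shown in this diff:

import librosa
from transformers import WhisperProcessor

# assumed checkpoint; the repo's Trainer wires up its own processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
target_sr = processor.feature_extractor.sampling_rate  # 16 kHz for Whisper

def prepare_example(audio_path: str, lyrics: str) -> dict:
    # decode the file and resample it to the model's rate in one call
    audio, _ = librosa.load(audio_path, sr=target_sr)

    # log-Mel spectrogram features (80 mel bins x 3000 frames for Whisper)
    input_features = processor.feature_extractor(
        audio, sampling_rate=target_sr
    ).input_features[0]

    # target transcript encoded to label ids
    labels = processor.tokenizer(lyrics).input_ids
    return {"input_features": input_features, "labels": labels}

One caveat on the caching branch: a dataset written with Dataset.save_to_disk is normally reloaded with datasets.load_from_disk; load_dataset("dataset/test/") expects a loading script or hub-style data files, so the else branch may fail as written.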
training/train.py (6 changes: 3 additions & 3 deletions)
@@ -71,7 +71,7 @@ def _compute_metrics(self, pred):
         label_str = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
 
         wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
-
+        print(f"WER: {wer}")
         return {"wer": wer}
 
     def train(self, dataset):
@@ -93,8 +93,8 @@ def train(self, dataset):
             per_device_eval_batch_size=8,
             predict_with_generate=True,
             generation_max_length=225,
-            save_steps=10,
-            eval_steps=10,
+            save_steps=80,
+            eval_steps=40,
             logging_steps=25,
             report_to=["tensorboard"],
             load_best_model_at_end=True,
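
Two details of the training/train.py change are easy to miss: the first hunk prints the word error rate at each evaluation, and the second raises save_steps/eval_steps from 10/10 to 80/40 while keeping save_steps a round multiple of eval_steps (80 = 2 x 40), which transformers enforces when load_best_model_at_end=True. A minimal sketch of both pieces under stated assumptions: output_dir and the steps-based evaluation strategy are not visible in the diff and are assumed here.

import evaluate
from transformers import Seq2SeqTrainingArguments

# word error rate, as computed in Trainer._compute_metrics
metric = evaluate.load("wer")

def compute_wer(pred_str, label_str):
    # WER as a percentage, printed once per evaluation step
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    print(f"WER: {wer}")
    return {"wer": wer}

args = Seq2SeqTrainingArguments(
    output_dir="./whisper-out",    # assumed; set outside this hunk
    evaluation_strategy="steps",   # assumed; required for eval_steps to apply
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=80,   # must stay a round multiple of eval_steps
    eval_steps=40,   # because load_best_model_at_end=True
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
)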
