Implement WebUI #13

Draft · wants to merge 1 commit into base: main
32 changes: 14 additions & 18 deletions models/megatts2.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -323,18 +325,16 @@ def __init__(
         self.hifi_gan.eval()

     def forward(
-        self,
-        wavs_dir: str,
-        text: str,
+        self,
+        audio_paths: List[str],
+        text: str
     ):
         mels_prompt = None
         # Make mrte mels
-        wavs = glob.glob(f'{wavs_dir}/*.wav')
         mels = torch.empty(0)
-        for wav in wavs:
-            y = librosa.load(wav, sr=HIFIGAN_SR)[0]
+        for audio_path in audio_paths:
+            y = librosa.load(audio_path, sr=HIFIGAN_SR)[0]
             y = librosa.util.normalize(y)
             # y = librosa.effects.trim(y, top_db=20)[0]
             y = torch.from_numpy(y)

             mel_spec = extract_mel_spec(y).transpose(0, 1)
@@ -346,30 +346,26 @@ def forward(
         mels = mels.unsqueeze(0)

         # G2P
-        phone_tokens = self.ttc.phone2token(
-            self.tt.tokenize_lty(self.tt.tokenize(text)))
+        phone_tokens = self.ttc.phone2token(self.tt.tokenize_lty(self.tt.tokenize(text)))
         phone_tokens = phone_tokens.unsqueeze(0)

         with torch.no_grad():
             tc_latent = self.generator.mrte.tc_latent(phone_tokens, mels)
             dt = self.adm.infer(tc_latent)[..., 0]
             tc_latent_expand = self.lr(tc_latent, dt)
-            tc_latent = F.max_pool1d(tc_latent_expand.transpose(
-                1, 2), 8, ceil_mode=True).transpose(1, 2)
+            tc_latent = F.max_pool1d(tc_latent_expand.transpose(1, 2), 8, ceil_mode=True).transpose(1, 2)
             p_codes = self.plm.infer(tc_latent)

             zq = self.generator.vqpe.vq.decode(p_codes.unsqueeze(0))
-            zq = rearrange(
-                zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, 8, -1)
+            zq = rearrange(zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, 8, -1)
             zq = rearrange(zq, "B T S D -> B (T S) D")
-            x = torch.cat(
-                [tc_latent_expand, zq[:, :tc_latent_expand.shape[1], :]], dim=-1)
+            x = torch.cat([tc_latent_expand, zq[:, :tc_latent_expand.shape[1], :]], dim=-1)
             x = rearrange(x, 'B T D -> B D T')
             x = self.generator.decoder(x)

         audio = self.hifi_gan.decode_batch(x.cpu())
-        audio_prompt = self.hifi_gan.decode_batch(
-            mels_prompt.unsqueeze(0).transpose(1, 2).cpu())
+        audio_prompt = self.hifi_gan.decode_batch(mels_prompt.unsqueeze(0).transpose(1, 2).cpu())
         audio = torch.cat([audio_prompt, audio], dim=-1)

         torchaudio.save('test.wav', audio[0], HIFIGAN_SR)
         return audio.squeeze()
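In effect, callers now pass an explicit list of prompt .wav paths instead of a directory for forward() to glob. A minimal usage sketch of the new signature, assuming the checkpoint, config, and symbol-table paths that webui.py below uses; the prompt file names and input text are placeholders:

from models.megatts import Megatts  # import path as used by webui.py below

megatts = Megatts(
    g_ckpt='generator.ckpt',
    g_config='configs/config_gan.yaml',
    plm_ckpt='plm.ckpt',
    plm_config='configs/config_plm.yaml',
    adm_ckpt='adm.ckpt',
    adm_config='configs/config_adm.yaml',
    symbol_table='unique_text_tokens.k2symbols'
)
megatts.eval()

# Prompt audio is now a list of explicit .wav paths (placeholder names here),
# not a directory that forward() globs for *.wav itself.
audio = megatts.forward(['prompt_1.wav', 'prompt_2.wav'], 'Text to synthesize.')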

3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,4 +3,5 @@ torchaudio==2.1.0+cu118
 torchvision==0.16.0+cu118
 lightning==2.1.2
 lhotse==1.17.0
-h5py
+h5py
+gradio
44 changes: 44 additions & 0 deletions webui.py
@@ -0,0 +1,44 @@
import gradio as gr
from models.megatts import Megatts
from modules.tokenizer import HIFIGAN_SR

megatts = Megatts(
    g_ckpt='generator.ckpt',
    g_config='configs/config_gan.yaml',
    plm_ckpt='plm.ckpt',
    plm_config='configs/config_plm.yaml',
    adm_ckpt='adm.ckpt',
    adm_config='configs/config_adm.yaml',
    symbol_table='unique_text_tokens.k2symbols'
)
megatts.eval()

def generate_audio(
    audio_files,
    text
):
    # Recent Gradio versions pass uploaded files as path strings; older ones
    # pass tempfile wrappers with a .name attribute, so accept both.
    audio_paths = [f if isinstance(f, str) else f.name for f in audio_files]
    audio_tensor = megatts.forward(audio_paths, text)
    audio_numpy = audio_tensor.cpu().numpy()
    # gr.Audio(type="numpy") expects a (sample_rate, data) tuple.
    return HIFIGAN_SR, audio_numpy

iface = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.File(
            file_count="multiple",
            file_types=[".wav"],
            label="Upload Audio Files"
        ),
        gr.Textbox(lines=2, label="Input Text")
    ],
    outputs=[
        gr.Audio(type="numpy", label="Generated Audio")
    ],
    title="MegaTTS2 Speech Synthesis",
    description="Upload your audio files (.wav format only) and enter text to generate speech."
)

if __name__ == "__main__":
    iface.launch()
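For reference, the interface can also be launched with explicit server options; a sketch using standard gr.Interface.launch() parameters, where the host, port, and share values are illustrative and not settings this PR configures:

# python webui.py serves the UI at http://127.0.0.1:7860 by default.
# To expose it on the local network, or create a temporary public
# gradio.live link with share=True:
iface.launch(server_name="0.0.0.0", server_port=7860, share=False)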