-
Notifications
You must be signed in to change notification settings - Fork 2
/
onnx_profile.py
47 lines (37 loc) · 1.06 KB
/
onnx_profile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import onnxruntime as rt
import onnx
import numpy as np
from omegaconf import OmegaConf
from data import VocoderDataset
from model.utils import STFT
import sys
# Read the exported HiFiPLN graph from disk and serialize it to bytes for
# the runtime.
model = onnx.load("out/HiFiPLN/hifipln.onnx")
model_bytes = model.SerializeToString()

# Enable ONNX Runtime's built-in profiler for this session.
options = rt.SessionOptions()
options.enable_profiling = True

# Offer every execution provider this onnxruntime installation reports.
sess = rt.InferenceSession(
    model_bytes,
    sess_options=options,
    providers=rt.get_available_providers(),
)
# One config file drives both the validation dataset and the mel front end.
config = OmegaConf.load("configs/hifipln.yaml")
valid_dataset = VocoderDataset(config, "valid")

# STFT takes its settings one-to-one from identically named config fields.
stft_fields = (
    "sample_rate",
    "n_fft",
    "win_length",
    "hop_length",
    "f_min",
    "f_max",
    "n_mels",
)
stft = STFT(**{field: getattr(config, field) for field in stft_fields})
# Run the ONNX model on validation data. The loop stops after the first
# item, so the profile covers a single inference call.
for idx, sample in enumerate(valid_dataset):
    print(f"Processing {idx}")
    audio = sample["audio"]
    f0 = sample["pitch"]

    # Compute the mel features and swap their last two axes
    # (presumably to match the layout the exported graph expects - confirm).
    mel = stft.get_mel(audio).transpose(-1, -2)
    print(mel.shape, f0.shape)

    # Feed the model its two named inputs; None requests all outputs.
    feeds = {"mel": mel.numpy(), "f0": f0.numpy()}
    res = sess.run(None, feeds)
    break

# Stop the profiler and print what end_profiling() returns
# (the name of the profile file, per the ONNX Runtime docs).
prof_file = sess.end_profiling()
print(prof_file)