MelSTFTLoss does not correctly register filterbanks as buffers #34

Open
csteinmetz1 opened this issue Jun 5, 2022 · 0 comments

# setup mel filterbank
if self.scale == "mel":
    assert sample_rate != None  # Must set sample rate to use mel scale
    assert n_bins <= fft_size  # Must be more FFT bins than Mel bins
    fb = librosa.filters.mel(sample_rate, fft_size, n_mels=n_bins)
    self.fb = torch.tensor(fb).unsqueeze(0)
elif self.scale == "chroma":
    assert sample_rate != None  # Must set sample rate to use chroma scale
    assert n_bins <= fft_size  # Must be more FFT bins than chroma bins
    fb = librosa.filters.chroma(sample_rate, fft_size, n_chroma=n_bins)
    self.fb = torch.tensor(fb).unsqueeze(0)

if scale is not None and device is not None:
    self.fb = self.fb.to(self.device)  # move filterbank to device

This breaks loss computation on GPU: self.fb is a plain tensor attribute, not a parameter or buffer, so calling .to(device) or .cuda() on the loss module does not move it, and it stays on the CPU unless the device argument was passed to the constructor. The fix is simple and should only require registering the filterbank as a buffer.

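Something like self.register_buffer("fb", torch.tensor(fb).unsqueeze(0)) in place of the self.fb = ... assignments (register_buffer expects a tensor, not the NumPy array returned by librosa), which would also make the manual self.fb.to(self.device) call unnecessary.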
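
A minimal sketch of the difference, not the library's actual code: the class names PlainAttr and Buffered are hypothetical, and a random tensor stands in for the librosa filterbank so the example runs without librosa or a GPU. It uses a dtype cast rather than a device move, since Module.to() handles both through the same mechanism and only touches parameters and registered buffers.

```python
import torch


class PlainAttr(torch.nn.Module):
    """Hypothetical module that stores the filterbank as a plain attribute."""

    def __init__(self, n_bins=8, n_freqs=17):
        super().__init__()
        # Random stand-in for the librosa filterbank (illustration only).
        self.fb = torch.rand(n_bins, n_freqs).unsqueeze(0)


class Buffered(torch.nn.Module):
    """Hypothetical module that registers the filterbank as a buffer."""

    def __init__(self, n_bins=8, n_freqs=17):
        super().__init__()
        fb = torch.rand(n_bins, n_freqs).unsqueeze(0)
        # Buffers are tracked by the module, so .to()/.cuda() move them too.
        self.register_buffer("fb", fb)


# Module.to() only touches parameters and registered buffers.
plain = PlainAttr().to(torch.float64)
buffered = Buffered().to(torch.float64)

plain_dtype = plain.fb.dtype        # left as created (float32 by default)
buffered_dtype = buffered.fb.dtype  # follows the module (float64)
buffer_names = [name for name, _ in buffered.named_buffers()]
```

One design note: buffers are included in state_dict() by default. If the filterbank should stay out of saved checkpoints, register_buffer also accepts persistent=False.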
