FIR filter input #65

Open
wants to merge 10 commits into base: main

Conversation

@jmoso13 commented Oct 20, 2023

In this PR I add an optional FIRFilter input to STFTLoss. If a filter is provided, it is assigned to self.prefilter; otherwise self.prefilter is set to None. If only the perceptual_weighting flag is set, self.prefilter is populated with an internally constructed FIRFilter.

If both an external FIRFilter is provided and the perceptual_weighting flag is set, an nn.Sequential variant that accepts two inputs is constructed to run both filters in sequence.
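For illustration, a rough usage sketch of the new argument. The prefilter argument and the butter_* parameters come from the diff below, but the filter_type value used here to select the Butterworth design is an assumption and may differ in the actual branch:

```python
import torch
from auraloss.freq import STFTLoss
from auraloss.perceptual import FIRFilter

# Hypothetical usage sketch. `prefilter` and the butter_* kwargs appear in this PR's
# diff; the filter_type value selecting the Butterworth path is assumed.
lowpass = FIRFilter(
    filter_type="butter",        # assumed selector for the new Butterworth construction
    fs=44100,
    butter_order=4,              # assumed order for the 4.5 kHz lowpass example below
    butter_freq=4500,
    butter_filter_type="lowpass",
)
loss_fn = STFTLoss(sample_rate=44100, perceptual_weighting=True, prefilter=lowpass)

pred = torch.randn(2, 1, 44100)
target = torch.randn(2, 1, 44100)
loss = loss_fn(pred, target)
```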

Tested on some audio input and appears to be working as expected. Below are some spectrograms of the different variations.

No Filter: (spectrogram image: audio_no_filter)

Only Perceptual Weighting: (spectrogram image: audio_only_percep_weight)

Only External 4.5k Lowpass Filter: (spectrogram image: audio_only_4k_lowpass)

Both Filters: (spectrogram image: audio_percep_weight_and_4k_lowpass)

Also included are changes to auraloss.perceptual.FIRFilter that allow for Butterworth filter construction, and a FIRSequential class in auraloss.utils that inherits from nn.Sequential and allows multiple inputs.
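For reference, a minimal sketch of what such a FIRSequential could look like, assuming each FIRFilter's forward takes and returns an (input, target) pair; the actual implementation in this branch may differ:

```python
import torch


class FIRSequential(torch.nn.Sequential):
    """Sequential container whose modules accept and return multiple tensors,
    e.g. the (input, target) pair used by auraloss.perceptual.FIRFilter."""

    def forward(self, *inputs):
        for module in self:
            inputs = module(*inputs)  # each module returns a tuple of tensors
        return inputs
```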

I haven't tried using this branch in a model yet, but it returns losses as expected when testing this repo on its own.

@csteinmetz1 (Owner) left a comment:

Thanks for these changes. Adding custom FIRFilter instances to the STFT loss will be great. I left some suggestions in my comments. Overall, let's try to remove the FIRSequential and retain current default behavior by creating two prefilters. One is the default prefilter and the new one is the custom_prefilter. Also, to make the FIRFilter class cleaner, we will want to move the butterworth parameter specification outside and perhaps pass in taps. Will need some more thinking. Thanks for your work on this and feel free to tell me different on any of my comments.

@@ -181,7 +184,10 @@ def __init__(
             raise ValueError(
                 f"`sample_rate` must be supplied when `perceptual_weighting = True`."
             )
-            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
+            if self.prefilter is None:

@csteinmetz1 (Owner):
Can we adjust the logic here? My thinking is that if someone sets perceptual_weighting=True we should apply the default filter. If someone also specifies prefilter=FIRFilter() then we should also apply this filter. This would create two separate logic branches here.

+            if self.prefilter is None:
+                self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
+            else:
+                self.prefilter = FIRSequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter)

@csteinmetz1 (Owner):
Following the above, I think this means we can simplify this by removing the need for FIRSequential. To achieve this we could think of having two separate prefilter attributes: for example, self.user_prefilter for the user-specified prefilter, and self.prefilter, as it is now, for the perceptual weighting. Then, in the forward() we can check for either and apply if needed. How does this sound?
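For illustration, a minimal sketch of this two-prefilter pattern; the wrapper class and the custom_prefilter name are illustrative only, not code from this PR:

```python
import torch
from auraloss.perceptual import FIRFilter


class TwoPrefilterSketch(torch.nn.Module):
    """Illustrative only: default perceptual-weighting prefilter kept separate
    from an optional user-supplied prefilter, each applied in forward()."""

    def __init__(self, sample_rate=44100, perceptual_weighting=False, prefilter=None):
        super().__init__()
        self.prefilter = None
        if perceptual_weighting:
            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)  # default A-weighting
        self.custom_prefilter = prefilter  # user-supplied FIRFilter, or None

    def forward(self, input, target):
        bs, chs, seq_len = input.size()
        # FIRFilter expects mono signals, so fold channels into the batch dimension
        input = input.view(bs * chs, 1, -1)
        target = target.view(bs * chs, 1, -1)
        if self.prefilter is not None:
            input, target = self.prefilter(input, target)
        if self.custom_prefilter is not None:
            input, target = self.custom_prefilter(input, target)
        # ...the STFT loss computation would follow here...
        return input.view(bs, chs, -1), target.view(bs, chs, -1)
```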

@@ -209,7 +215,7 @@ def stft(self, x):
     def forward(self, input: torch.Tensor, target: torch.Tensor):
         bs, chs, seq_len = input.size()

-        if self.perceptual_weighting:  # apply optional A-weighting via FIR filter
+        if self.prefilter is not None:  # apply prefilter

@csteinmetz1 (Owner):
Here we can add two checks, one for if self.prefilter is not None: and one for if self.user_prefilter is not None:.

@@ -8,3 +8,9 @@ def apply_reduction(losses, reduction="none"):
     elif reduction == "sum":
         losses = losses.sum()
     return losses

+class FIRSequential(torch.nn.Sequential):

@csteinmetz1 (Owner):
Following the above, ideally we can completely remove the need for this class.

@@ -55,14 +55,17 @@ class FIRFilter(torch.nn.Module):
         a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates.
     """

-    def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
+    def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_order=2, butter_freq=(250, 5000), butter_filter_type="bandpass", plot=False):

@csteinmetz1 (Owner):
I think adding the Butterworth parameters to this constructor makes the behavior a bit confusing, since we also have the filter_type parameter, which will get ignored. I do not know the best solution yet, but it feels like we should instead just pass taps into this constructor. Then we will need to supply some helper functions for hp and Butterworth filters that will produce a tensor of taps. What do you think?
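For illustration, one way such a helper could produce a tensor of taps approximating a Butterworth response. This is a sketch only; the function name and the fit via scipy.signal.firwin2 are assumptions, not part of this PR:

```python
import numpy as np
import scipy.signal
import torch


def butter_fir_taps(ntaps=101, order=2, freq=(250, 5000), btype="bandpass", fs=44100):
    """Hypothetical helper: fit FIR taps to the magnitude response of a Butterworth design."""
    b, a = scipy.signal.butter(order, freq, btype=btype, fs=fs)
    # sample the Butterworth magnitude response on a grid from DC to Nyquist
    freq_grid = np.linspace(0.0, fs / 2, 512)
    _, h = scipy.signal.freqz(b, a, worN=freq_grid, fs=fs)
    # design a linear-phase FIR filter matching that magnitude response
    taps = scipy.signal.firwin2(ntaps, freq_grid, np.abs(h), fs=fs)
    return torch.tensor(taps, dtype=torch.float32)
```

A taps-based FIRFilter constructor could then load such a tensor into its convolution kernel.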

@@ -132,6 +134,7 @@ def __init__(
         self.reduction = reduction
         self.mag_distance = mag_distance
         self.device = device
+        self.prefilter = prefilter

@csteinmetz1 (Owner):
As I suggest below, let's change this to something like self.user_prefilter or another name, maybe self.custom_prefilter, kept totally separate from self.prefilter, which will be used for the default perceptual weighting.
