Testing STFT/Magnitude against STFTTflite/MagnitudeTflite #142

Open
daniel-deychakiwsky opened this issue Mar 14, 2024 · 2 comments

@daniel-deychakiwsky

I'm testing the STFT and Magnitude layers against their TFLite counterparts (STFTTflite and MagnitudeTflite), and the outputs fail an equality check for the same set of arguments. Am I doing something wrong here, or is this expected for some reason?

from kapre import (
    STFT,
    ApplyFilterbank,
    Magnitude,
    MagnitudeTflite,
    MagnitudeToDecibel,
    STFTTflite,
)
from tensorflow.keras.models import Sequential


def get_melgram_layer(
    n_fft: int,
    win_length: int,
    hop_length: int,
    window_name: str,
    pad_begin: bool,
    pad_end: bool,
    sample_rate: int,
    n_mels: int,
    mel_f_min: int,
    mel_f_max: int,
    mel_htk: bool,
    mel_norm: str,
    return_decibel: bool,
    db_amin: float,
    db_ref_value: float,
    db_dynamic_range: float,
    input_data_format: str,
    output_data_format: str,
    name: str,
    for_device: bool,
) -> Sequential:
    melgram_layers = Sequential(name=name)

    if for_device:
        melgram_layers.add(
            STFTTflite(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=window_name,
                pad_begin=pad_begin,
                pad_end=pad_end,
                input_data_format=input_data_format,
                output_data_format=output_data_format,
            )
        )
        melgram_layers.add(MagnitudeTflite())
    else:
        melgram_layers.add(
            STFT(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                window_name=window_name,
                pad_begin=pad_begin,
                pad_end=pad_end,
                input_data_format=input_data_format,
                output_data_format=output_data_format,
            )
        )
        melgram_layers.add(Magnitude())
    # melgram_layers.add(
    #     ApplyFilterbank(
    #         type="mel",
    #         filterbank_kwargs={
    #             "sample_rate": sample_rate,
    #             "n_freq": n_fft // 2 + 1,
    #             "n_mels": n_mels,
    #             "f_min": mel_f_min,
    #             "f_max": mel_f_max,
    #             "htk": mel_htk,
    #             "norm": mel_norm,
    #         },
    #         data_format=output_data_format,
    #     )
    # )
    # if return_decibel:
    #     melgram_layers.add(
    #         MagnitudeToDecibel(
    #             ref_value=db_ref_value, amin=db_amin, dynamic_range=db_dynamic_range
    #         )
    #     )
    return melgram_layers

import numpy as np


def test_get_melgram_layer():
    kwargs = {
        "n_fft": 2048,
        "win_length": 1024,
        "hop_length": 1024,
        "window_name": "hann_window",
        "pad_begin": False,
        "pad_end": False,
        "sample_rate": 22050,
        "n_mels": 256,
        "mel_f_min": 0,
        "mel_f_max": 22050 // 2,
        "mel_htk": False,
        "mel_norm": "slaney",
        "return_decibel": True,
        "db_amin": 1e-05,
        "db_ref_value": 1.0,
        "db_dynamic_range": 150.0,
        "input_data_format": "channels_last",
        "output_data_format": "channels_last",
        "name": "log_mel_spectrogram",
    }

    fake_audio = np.ones((1, 22050, 1))

    kwargs.update({"for_device": False})
    training_melgram = get_melgram_layer(**kwargs)(fake_audio)
    kwargs.update({"for_device": True})
    edge_serving_melgram = get_melgram_layer(**kwargs)(fake_audio)

    np.testing.assert_allclose(training_melgram, edge_serving_melgram)

>       np.testing.assert_allclose(training_melgram, edge_serving_melgram)
E       AssertionError: 
E       Not equal to tolerance rtol=1e-07, atol=0
E       
E       Mismatched elements: 21480 / 21525 (99.8%)
E       Max absolute difference: 0.00228808
E       Max relative difference: 2939.9194
E        x: array([[[[5.120000e+02],
E                [4.345991e+02],
E                [2.560000e+02],...
E        y: array([[[[5.120000e+02],
E                [4.345991e+02],
E                [2.560000e+02],...
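
The report above shows a small maximum absolute difference (about 0.0023 on values up to 512) alongside a very large maximum relative difference, which is what happens when some bins are close to zero and the check uses only `rtol` with `atol=0`. The following is a minimal sketch of that effect using NumPy alone. The arrays are hypothetical stand-ins rather than real kapre outputs, and the tolerance values are assumptions to be tuned per use case.

```python
import numpy as np

# Hypothetical stand-ins for the two magnitude spectrograms: a few large bins
# followed by near-zero bins (illustrative values, not real kapre outputs).
reference = np.concatenate([np.array([512.0, 434.5991, 256.0]), np.full(5, 1e-6)])

# Deterministic perturbation of at most 2e-3 in absolute terms.
noise = 2e-3 * np.abs(np.cos(np.arange(reference.size)))
candidate = reference + noise

abs_diff = np.abs(candidate - reference)
rel_diff = abs_diff / np.abs(reference)
print("max abs diff:", abs_diff.max())
print("max rel diff:", rel_diff.max())

# Absolute tolerance scaled to the peak magnitude (assumed threshold).
atol = 1e-5 * np.abs(reference).max()
np.testing.assert_allclose(candidate, reference, rtol=1e-4, atol=atol)
```

With `atol` tied to the peak magnitude, near-zero bins no longer dominate the comparison, while a genuinely large deviation in any bin would still fail.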
@daniel-deychakiwsky
Author

The test above used a DC signal. With a more realistic signal, the difference is significantly smaller. I'm not sure whether anything can be added to close the gap further.
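
One way to build a less degenerate input than an all-ones array is a chirp with a little noise, shaped like the `fake_audio` in the test. This is only a sketch: `make_test_audio` is a hypothetical helper, and the frequency range and noise level are arbitrary assumptions.

```python
import numpy as np


def make_test_audio(sample_rate=22050, duration_s=1.0, seed=0):
    """Hypothetical helper: linear chirp plus low-level noise, shaped (batch, time, channel)."""
    n = int(sample_rate * duration_s)
    t = np.arange(n) / sample_rate
    # Linear chirp from 100 Hz up to a quarter of the sample rate (assumed range).
    f0, f1 = 100.0, sample_rate / 4
    phase = 2 * np.pi * (f0 * t + 0.5 * (f1 - f0) * t**2 / duration_s)
    rng = np.random.default_rng(seed)
    audio = 0.5 * np.sin(phase) + 0.01 * rng.standard_normal(n)
    # Add batch and channel axes to match "channels_last" audio input.
    return audio.astype(np.float32)[np.newaxis, :, np.newaxis]


fake_audio = make_test_audio()
print(fake_audio.shape)  # (1, 22050, 1)
```

A signal like this spreads energy across many frequency bins, so fewer bins sit near zero than with a DC input.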

@keunwoochoi
Owner

Hi. There has been some discussion of this. In general, it's possible to make the two outputs close enough by tuning parameters, though it may not be apparent how.

One thing: it's better to compare them without the decibel scaling, as it exaggerates numerical differences where abs(x) < 1.

https://groups.google.com/a/ismir.net/g/community/c/LiVRv4I7asw/m/H6Ag-MxGAQAJ

https://colab.research.google.com/drive/1ptS1UkpHa-dW8w7WEf8xTE63mEQg8NQZ

tensorflow/tensorflow#32373

tensorflow/tensorflow#15134
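
To illustrate the point about decibel scaling, here is a small NumPy sketch showing how the same absolute gap turns into a much larger decibel gap as the magnitude drops below 1. The `to_decibel` function is a generic `10 * log10` with a floor, assumed for illustration; it is not necessarily identical to kapre's `MagnitudeToDecibel`.

```python
import numpy as np


def to_decibel(x, amin=1e-5):
    # Generic 10 * log10 with a floor (assumed form, not kapre's exact code).
    return 10.0 * np.log10(np.maximum(x, amin))


abs_gap = 1e-4  # identical absolute difference at every level
levels = (100.0, 1.0, 1e-3)
db_gaps = []
for level in levels:
    gap_db = abs(to_decibel(level + abs_gap) - to_decibel(level))
    db_gaps.append(float(gap_db))
    print(f"level={level:g}  abs gap={abs_gap:g}  dB gap={gap_db:.6f}")
```

The decibel gap grows by orders of magnitude as the level shrinks, which is why comparing the linear magnitudes gives a fairer picture of how closely the two implementations agree.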
