PatchInferer with AvgMerger and filter_fn leads to NaNs #7743

Open
nicholas-greig opened this issue May 6, 2024 · 2 comments

Comments

@nicholas-greig

Describe the bug
On the current master branch, when PatchInferer is used with an AvgMerger (the default merger class) and a filter_fn, the counts stay zero everywhere filter_fn filters out a region, because no patch is ever merged there. When AvgMerger.finalize() is then called, the self.values attribute of AvgMerger is divided in place by the self.counts tensor. Because self.counts is initialised to zero and nothing is accumulated into self.values in those regions, the division there is 0 / 0, which produces NaN. As a result, every region that filter_fn filters out ends up as NaN in the output.
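To make the 0 / 0 mechanism concrete, here is a toy averaging merge in plain PyTorch. It is an illustration only, not MONAI's AvgMerger: the sizes, the `keep` filter and the all-ones "network output" are hypothetical stand-ins chosen so that the right half of the image is filtered out.

```python
import torch

# Toy stand-in for an averaging merger (NOT MONAI's AvgMerger itself):
# `values` accumulates patch outputs, `counts` records how many patches
# covered each pixel. Both start at zero.
H, W, P = 8, 8, 4  # image size and patch size (illustrative values)
values = torch.zeros(1, 1, H, W)
counts = torch.zeros(1, 1, H, W)


def keep(location):
    # Hypothetical filter: drop patches whose column start is past the midpoint.
    return location[1] < W // 2


for r in range(0, H, P):
    for c in range(0, W, P):
        if not keep((r, c)):
            continue  # filtered patch: nothing is added to values or counts
        patch = torch.ones(1, 1, P, P)  # pretend network output
        values[..., r:r + P, c:c + P] += patch
        counts[..., r:r + P, c:c + P] += 1

# In-place average, mirroring values.div_(counts) in finalize().
values.div_(counts)

# Pixels never covered hold 0 / 0, which is NaN in IEEE-754 arithmetic.
nan_pixels = int(torch.isnan(values).sum())
print(nan_pixels)  # 32: the right half (8 rows x 4 columns) was filtered out
```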

A quick workaround is an in-place assignment to counts (setting the zero counts to 1, for example), so that these values become zero after the in-place division. However, if the output is supposed to be real-valued/continuous, it might be better to overwrite these values in place with the smallest representable value (using torch.finfo(self.values.dtype).min or something similar). Monkey patching the outputs of an Inferer isn't ideal either: a network can produce NaNs because of exploding weights or overflow during training, and masking this by overwriting NaNs with zero would merely obscure that problem.
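One way to separate the two NaN sources is to build the mask from the counts rather than from `torch.isnan(...)` on the output. The sketch below assumes access to accumulated `values` and `counts` tensors like AvgMerger's `self.values` and `self.counts`; `masked_average` is a hypothetical helper, not part of MONAI's API. A `fill_value` of `0.0` reproduces the "set the zero counts to 1" workaround, and `torch.finfo(values.dtype).min` gives the smallest-value variant.

```python
import torch


def masked_average(values: torch.Tensor, counts: torch.Tensor, fill_value: float) -> torch.Tensor:
    """Average accumulated patch outputs, filling pixels that no patch covered.

    Only pixels with counts == 0 are overwritten, so a NaN that the network
    itself produced (where counts > 0) is left in place and stays visible.
    """
    uncovered = counts == 0
    # clamp_min(1) turns the 0 / 0 at uncovered pixels into 0 / 1; those
    # entries are then replaced with fill_value.
    averaged = values / counts.clamp_min(1)
    return averaged.masked_fill(uncovered, fill_value)


# Three pixels: covered twice, covered once with a network NaN, never covered.
values = torch.tensor([[2.0, float("nan"), 0.0]])
counts = torch.tensor([[2.0, 1.0, 0.0]])

print(masked_average(values, counts, 0.0))  # same effect as setting zero counts to 1
print(masked_average(values, counts, torch.finfo(values.dtype).min))
```

The helper works out of place for readability; an in-place variant closer to `finalize()` could chain `values.div_(counts.clamp_min(1)).masked_fill_(uncovered, fill_value)`.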

@KumoLiu
Contributor

KumoLiu commented May 6, 2024

Hi @nicholas-greig, could you please share a small piece of code with which I can reproduce the issue?

Thanks.

@nicholas-greig
Author

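```python
import matplotlib.pyplot as plt
import torch

from monai.inferers.inferer import PatchInferer
from monai.inferers.splitter import SlidingWindowSplitter

H, W = 512, 512


def filter_fn(patch, location):
    # Keep a patch only if its start column (second spatial coordinate)
    # does not lie past the horizontal midpoint.
    return location[1] <= H // 2


# Non-overlapping 128x128 patches; filtered patches are never sent to the network.
splitter = SlidingWindowSplitter((128, 128), overlap=0, offset=0, filter_fn=filter_fn)

# PatchInferer uses AvgMerger by default.
inferer = PatchInferer(splitter)

inputs = torch.randn((1, 1, H, W))
# Identity "network", so the output should equal the input wherever a patch was kept.
outputs = inferer(inputs=inputs, network=lambda x: x)

# Count NaN pixels in the merged output (nonzero because of this bug).
print(torch.sum(torch.isnan(outputs[0])))

# Visualise where the NaNs are.
plt.imshow(torch.isnan(outputs[0]).squeeze())
plt.show()
```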
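As a sanity check on what the reproduction should show, the filtered region can be predicted with plain arithmetic. This assumes that, with `overlap=0` and `offset=0`, SlidingWindowSplitter tiles the 512x512 input with 128x128 patches starting at multiples of 128, and that `location` holds the spatial start coordinates (row, column). Under those assumptions only the column of patches starting at 384 is dropped (a start of 256 is not greater than `H // 2`).

```python
# Pure-Python coverage arithmetic for the repro (no MONAI needed).
# Assumption: with overlap=0 and offset=0 the splitter tiles the 512x512
# input with 128x128 patches whose start locations are multiples of 128.
H, W, P = 512, 512, 128


def filter_fn(patch, location):
    # Same rule as the repro: keep a patch unless its column start exceeds H // 2.
    return location[1] <= H // 2


filtered_starts = []
for row in range(0, H, P):
    for col in range(0, W, P):
        if not filter_fn(None, (row, col)):  # the patch argument is unused here
            filtered_starts.append((row, col))

# Patches do not overlap, so each filtered patch leaves exactly P * P pixels
# with counts == 0, i.e. pixels that become NaN after the division.
uncovered = len(filtered_starts) * P * P
print(filtered_starts)  # [(0, 384), (128, 384), (256, 384), (384, 384)]
print(uncovered)  # 65536, a 512 x 128 band on the right-hand edge
```

If the assumptions hold, the NaN count printed by the reproduction should match this figure, and the `imshow` plot should show a single band on the right.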
