Metadata/Tracing tracking fails after catching an exception #7701

Open
idinsmore1 opened this issue Apr 23, 2024 · 4 comments

@idinsmore1

idinsmore1 commented Apr 23, 2024

Describe the bug
When a CUDA OOM RuntimeError raised while postprocessing segmentations is caught so that postprocessing can be retried on the CPU, the retry always fails with RuntimeError("Transform Tracing must be enabled to get the most recent transform."). After that, every subsequent postprocessing operation fails with the same error until a new dataloader is defined.

To Reproduce
This is a bit tricky to reproduce: depending on your GPU, run inference on a model whose output fits in GPU memory, but whose postprocessing (particularly the Invertd transform) runs out of GPU memory. The example loop below shows the logic flow. I have double-checked that after the RuntimeError is caught, MONAIEnvVars.trace_transform() == 1 and also preprocessing.tracing == gpu_postprocessing.tracing == cpu_postprocessing.tracing == True (a sketch of this check is shown right after the snippet below).

# Imports assumed for this snippet; task_config, data_dict, model, model_config,
# task, batch_size, device and write_prediction are defined elsewhere.
import time

import numpy as np
import torch
from torch.cuda.amp import autocast

from monai.data import CacheDataset, ThreadDataLoader, decollate_batch
from monai.inferers import SlidingWindowInfererAdapt
from monai.transforms import (
    Activationsd, AsDiscreted, Compose, CropForegroundd, EnsureChannelFirstd,
    EnsureTyped, Invertd, LoadImaged, NormalizeIntensityd, Orientationd,
    Spacingd, SqueezeDimd, ThresholdIntensityd, ToNumpyd,
)

preprocessing = Compose([
            LoadImaged(keys=['image']),
            EnsureChannelFirstd(keys=['image']),
            ThresholdIntensityd(keys=['image'], threshold=task_config['percentile_95'], above=False, cval=task_config['percentile_95']),
            ThresholdIntensityd(keys=['image'], threshold=task_config['percentile_05'], above=True, cval=task_config['percentile_05']),
            NormalizeIntensityd(keys=['image'], subtrahend=task_config['mean'], divisor=task_config['std']),
            CropForegroundd(keys=["image"], source_key="image", allow_smaller=True, select_fn=lambda x: x > task_config['crop_threshold']),
            Orientationd(keys=['image'], axcodes='RAS'),
            Spacingd(keys=['image'], pixdim=task_config['spacing'], mode='bilinear'),
            EnsureTyped(keys=['image'], track_meta=True)
        ])
postprocessing_transform = Compose([
      Activationsd(keys=['pred'], softmax=True),
      AsDiscreted(keys=['pred'], argmax=True),
      Invertd(keys=['pred'], transform=preprocessing, orig_keys='image', meta_keys='image_meta_dict', nearest_interp=True, to_tensor=True),
      SqueezeDimd(keys=['pred'], dim=0),
      ToNumpyd(keys=['pred'], dtype=np.uint8)
  ])
gpu_postprocessing = Compose([EnsureTyped(keys=['pred'], device=device), postprocessing_transform])
cpu_postprocessing = Compose([EnsureTyped(keys=['pred'], device='cpu'), postprocessing_transform])

dataset = CacheDataset(data_dict, cache_rate=1.0, transform=preprocessing, num_workers=4)
dataloader = ThreadDataLoader(dataset, batch_size=1, num_workers=0, pin_memory=True)
# set up the adaptive inferer
inferer = SlidingWindowInfererAdapt(roi_size=model_config[task]['patch_size'], sw_batch_size=batch_size, overlap=0.5)
# Run the inference loop
with autocast():
    with torch.no_grad():
        for data in dataloader:
            images = data['image'].to(device)
            # Run inference
            start_time = time.time()
            pred = inferer(inputs=images, network=model)
            inference_time = round(time.time() - start_time, 2)
            data['pred'] = pred
            # Delete the images to save gpu memory
            del images 
            # Run postprocessing. Only have one item so take index 0
            processing_start = time.time()
            # Attempt to run postprocessing on GPU, if it fails due to OOM, run it on CPU
            # If the prediction is on CPU (== -1), we just go right to CPU postprocessing
            if data['pred'].get_device() != -1:
                try:
                    out = [gpu_postprocessing(i) for i in decollate_batch(data)][0]
                except RuntimeError as e: # this is almost always an OOM error
                    print('Switching to CPU for postprocessing')
                    out = [cpu_postprocessing(i) for i in decollate_batch(data)][0] # After the first attempt of this, every following postprocessing transformation fails
            else:
                out = [cpu_postprocessing(i) for i in decollate_batch(data)][0]
            write_prediction(out)
            del pred
            del out
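
For reference, a rough sketch of the tracing checks described above, run after the RuntimeError is caught (MONAIEnvVars is imported from monai.utils, as referenced in the description; this is illustrative only):

# Illustrative check of the tracing flags after the exception is caught
from monai.utils import MONAIEnvVars

print(MONAIEnvVars.trace_transform())  # environment-level tracing setting (MONAI_TRACE_TRANSFORM); not "0" here
for compose in (preprocessing, gpu_postprocessing, cpu_postprocessing):
    # top-level flag plus the flag on each direct child transform
    print(compose.tracing, all(getattr(t, "tracing", True) for t in compose.transforms))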

Expected behavior
I would expect the loop to continue as expected, only switching the tensor to cpu for postprocessing and the remaining data in the dataloader to be unaffected.

Environment

Printing MONAI config...

MONAI version: 1.3.0
Numpy version: 1.26.3
Pytorch version: 2.1.2.post301
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f
MONAI file: /home//mambaforge/envs/monai/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.0
scikit-image version: 0.22.0
scipy version: 1.12.0
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.0
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...

System: Linux
Linux version: Ubuntu 20.04.6 LTS
Platform: Linux-5.4.0-146-generic-x86_64-with-glibc2.31
Processor: x86_64
Machine: x86_64
Python version: 3.11.7
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 48
Num logical CPUs: 96
Num usable CPUs: 96
CPU usage (%): [8.1, 0.0, 100.0, 1.8, 0.0, 100.0, 1.8, 0.6, 1.2, 0.6, 2.4, 0.0, 0.0, 100.0, 1.2, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.0, 100.0, 0.6, 0.0, 0.6, 97.0, 0.0, 100.0, 0.0, 100.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.0, 0.0, 1.2, 100.0, 100.0, 100.0, 100.0, 100.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.6, 10.9, 100.0, 1.8, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.7, 1.2, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
CPU freq. (MHz): 3086
Load avg. in last 1, 5, 15 mins (%): [21.0, 19.7, 20.2]
Disk usage (%): 90.7
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1510.6
Available memory (GB): 1290.1
Used memory (GB): 163.5

================================
Printing GPU config...

Num GPUs: 16
Has CUDA: True
CUDA version: 11.2
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8800
Current device: 0
Library compiled for CUDA architectures: ['sm_35', 'sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_86']
GPU 0 Name: Tesla V100-SXM3-32GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 80
GPU 0 Total memory (GB): 31.7
GPU 0 CUDA capability (maj.min): 7.0
GPU 1 Name: Tesla V100-SXM3-32GB
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 80
GPU 1 Total memory (GB): 31.7
GPU 1 CUDA capability (maj.min): 7.0
GPU 2 Name: Tesla V100-SXM3-32GB
GPU 2 Is integrated: False
GPU 2 Is multi GPU board: False
GPU 2 Multi processor count: 80
GPU 2 Total memory (GB): 31.7
GPU 2 CUDA capability (maj.min): 7.0
GPU 3 Name: Tesla V100-SXM3-32GB
GPU 3 Is integrated: False
GPU 3 Is multi GPU board: False
GPU 3 Multi processor count: 80
GPU 3 Total memory (GB): 31.7
GPU 3 CUDA capability (maj.min): 7.0
GPU 4 Name: Tesla V100-SXM3-32GB
GPU 4 Is integrated: False
GPU 4 Is multi GPU board: False
GPU 4 Multi processor count: 80
GPU 4 Total memory (GB): 31.7
GPU 4 CUDA capability (maj.min): 7.0
GPU 5 Name: Tesla V100-SXM3-32GB
GPU 5 Is integrated: False
GPU 5 Is multi GPU board: False
GPU 5 Multi processor count: 80
GPU 5 Total memory (GB): 31.7
GPU 5 CUDA capability (maj.min): 7.0
GPU 6 Name: Tesla V100-SXM3-32GB
GPU 6 Is integrated: False
GPU 6 Is multi GPU board: False
GPU 6 Multi processor count: 80
GPU 6 Total memory (GB): 31.7
GPU 6 CUDA capability (maj.min): 7.0
GPU 7 Name: Tesla V100-SXM3-32GB
GPU 7 Is integrated: False
GPU 7 Is multi GPU board: False
GPU 7 Multi processor count: 80
GPU 7 Total memory (GB): 31.7
GPU 7 CUDA capability (maj.min): 7.0
GPU 8 Name: Tesla V100-SXM3-32GB
GPU 8 Is integrated: False
GPU 8 Is multi GPU board: False
GPU 8 Multi processor count: 80
GPU 8 Total memory (GB): 31.7
GPU 8 CUDA capability (maj.min): 7.0
GPU 9 Name: Tesla V100-SXM3-32GB
GPU 9 Is integrated: False
GPU 9 Is multi GPU board: False
GPU 9 Multi processor count: 80
GPU 9 Total memory (GB): 31.7
GPU 9 CUDA capability (maj.min): 7.0
GPU 10 Name: Tesla V100-SXM3-32GB
GPU 10 Is integrated: False
GPU 10 Is multi GPU board: False
GPU 10 Multi processor count: 80
GPU 10 Total memory (GB): 31.7
GPU 10 CUDA capability (maj.min): 7.0
GPU 11 Name: Tesla V100-SXM3-32GB
GPU 11 Is integrated: False
GPU 11 Is multi GPU board: False
GPU 11 Multi processor count: 80
GPU 11 Total memory (GB): 31.7
GPU 11 CUDA capability (maj.min): 7.0
GPU 12 Name: Tesla V100-SXM3-32GB
GPU 12 Is integrated: False
GPU 12 Is multi GPU board: False
GPU 12 Multi processor count: 80
GPU 12 Total memory (GB): 31.7
GPU 12 CUDA capability (maj.min): 7.0
GPU 13 Name: Tesla V100-SXM3-32GB
GPU 13 Is integrated: False
GPU 13 Is multi GPU board: False
GPU 13 Multi processor count: 80
GPU 13 Total memory (GB): 31.7
GPU 13 CUDA capability (maj.min): 7.0
GPU 14 Name: Tesla V100-SXM3-32GB
GPU 14 Is integrated: False
GPU 14 Is multi GPU board: False
GPU 14 Multi processor count: 80
GPU 14 Total memory (GB): 31.7
GPU 14 CUDA capability (maj.min): 7.0
GPU 15 Name: Tesla V100-SXM3-32GB
GPU 15 Is integrated: False
GPU 15 Is multi GPU board: False
GPU 15 Multi processor count: 80
GPU 15 Total memory (GB): 31.7
GPU 15 CUDA capability (maj.min): 7.0

Additional context
Everything works as expected until the error is caught for the first time; CPU and GPU postprocessing produce identical output as long as no error occurs. It's also worth mentioning that the same tracing failure occurs if I catch the error without retrying postprocessing, just to keep the loop running: all remaining postprocessing fails with the same tracing RuntimeError.

@KumoLiu
Contributor

KumoLiu commented Apr 24, 2024

Hi @idinsmore1, I guess the error may be due to the ToNumpyd in postprocessing_transform. Could you please remove it and try again?
Thanks.

@idinsmore1
Author

idinsmore1 commented Apr 24, 2024

Hi @KumoLiu, unfortunately this did not work; here's the full traceback:

OutOfMemoryError Traceback (most recent call last)
File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
140 return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141 return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142 except Exception as e:
143 # if in debug mode, don't swallow exception so that the breakpoint
144 # appears where the exception was raised.

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
96 return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/dictionary.py:527, in Spacingd.inverse(self, data)
526 for key in self.key_iterator(d):
--> 527 d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key]))
528 return d

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/array.py:543, in Spacing.inverse(self, data)
542 def inverse(self, data: torch.Tensor) -> torch.Tensor:
--> 543 return self.sp_resample.inverse(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/array.py:248, in SpatialResample.inverse(self, data)
246 with self.trace_transform(False):
247 # we can't use self.__call__ in case a child class calls this inverse.
--> 248 out: torch.Tensor = SpatialResample.__call__(self, data, **kw_args)
249 kw_args["src_affine"] = kw_args.get("dst_affine")

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/array.py:223, in SpatialResample.__call__(self, img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype, lazy)
222 lazy_ = self.lazy if lazy is None else lazy
--> 223 return spatial_resample(
224 img,
225 dst_affine,
226 spatial_size,
227 mode,
228 padding_mode,
229 align_corners,
230 dtype_pt,
231 lazy=lazy_,
232 transform_info=self.get_transform_info(),
233 )

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/functional.py:178, in spatial_resample(img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info)
175 affine_xform = AffineTransform( # type: ignore
176 normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
177 )
--> 178 img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0) # type: ignore
179 if additional_dims:

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/networks/layers/spatial_transforms.py:579, in AffineTransform.forward(self, src, theta, spatial_size)
575 raise ValueError(
576 f"affine and image batch dimension must match, got affine={theta.shape[0]} image={src_size[0]}."
577 )
--> 579 grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)
580 dst = nn.functional.grid_sample(
581 input=src.contiguous(),
582 grid=grid,
(...)
585 align_corners=self.align_corners,
586 )

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/torch/nn/functional.py:4399, in affine_grid(theta, size, align_corners)
4397 raise ValueError(f"Expected non-zero, positive output size. Got {size}")
-> 4399 return torch.affine_grid_generator(theta, size, align_corners)

OutOfMemoryError: CUDA out of memory. Tried to allocate 11.53 GiB. GPU 0 has a total capacty of 31.75 GiB of which 5.47 GiB is free. Including non-PyTorch memory, this process has 26.27 GiB memory in use. Of the allocated memory 24.98 GiB is allocated by PyTorch, and 158.96 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
140 return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141 return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142 except Exception as e:
143 # if in debug mode, don't swallow exception so that the breakpoint
144 # appears where the exception was raised.

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
96 return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/post/dictionary.py:706, in Invertd.__call__(self, data)
705 with allow_missing_keys_mode(self.transform): # type: ignore
--> 706 inverted = self.transform.inverse(input_dict)
708 # save the inverted data

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:364, in Compose.inverse(self, data)
363 for t in reversed(invertible_transforms):
--> 364 data = apply_transform(
365 t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
366 )
367 return data

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
170 _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <bound method Spacingd.inverse of <monai.transforms.spatial.dictionary.Spacingd object at 0x7f2485cff450>>

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
Cell In[6], line 15
14 print('postprocessing on gpu')
---> 15 out = [gpu_postprocessing(i) for i in decollate_batch(data)][0]
16 except RuntimeError as e:

Cell In[6], line 15, in <listcomp>(.0)
14 print('postprocessing on gpu')
---> 15 out = [gpu_postprocessing(i) for i in decollate_batch(data)][0]
16 except RuntimeError as e:

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:335, in Compose.__call__(self, input_, start, end, threading, lazy)
334 _lazy = self.lazy if lazy is None else lazy
--> 335 result = execute_compose(
336 input_,
337 transforms=self.transforms,
338 start=start,
339 end=end,
340 map_items=self.map_items,
341 unpack_items=self.unpack_items,
342 lazy=_lazy,
343 overrides=self.overrides,
344 threading=threading,
345 log_stats=self.log_stats,
346 )
348 return result

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:111, in execute_compose(data, transforms, map_items, unpack_items, start, end, lazy, overrides, threading, log_stats)
110 _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
--> 111 data = apply_transform(
112 _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
113 )
114 data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
170 _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <monai.transforms.post.dictionary.Invertd object at 0x7f2485d07ad0>

During handling of the above exception, another exception occurred:

RuntimeError Traceback (most recent call last)
File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
140 return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141 return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142 except Exception as e:
143 # if in debug mode, don't swallow exception so that the breakpoint
144 # appears where the exception was raised.

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
96 return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/dictionary.py:527, in Spacingd.inverse(self, data)
526 for key in self.key_iterator(d):
--> 527 d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key]))
528 return d

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/array.py:543, in Spacing.inverse(self, data)
542 def inverse(self, data: torch.Tensor) -> torch.Tensor:
--> 543 return self.sp_resample.inverse(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/spatial/array.py:236, in SpatialResample.inverse(self, data)
235 def inverse(self, data: torch.Tensor) -> torch.Tensor:
--> 236 transform = self.pop_transform(data)
237 # Create inverse transform

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/inverse.py:328, in TraceableTransform.pop_transform(self, data, key, check)
314 """
315 Return and pop the most recent transform.
316
(...)
326 - RuntimeError: data is neither MetaTensor nor dictionary
327 """
--> 328 return self.get_most_recent_transform(data, key, check, pop=True)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/inverse.py:299, in TraceableTransform.get_most_recent_transform(self, data, key, check, pop)
298 if not self.tracing:
--> 299 raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
300 if isinstance(data, MetaTensor):

RuntimeError: Transform Tracing must be enabled to get the most recent transform.

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
140 return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141 return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142 except Exception as e:
143 # if in debug mode, don't swallow exception so that the breakpoint
144 # appears where the exception was raised.

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
96 return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/post/dictionary.py:706, in Invertd.__call__(self, data)
705 with allow_missing_keys_mode(self.transform): # type: ignore
--> 706 inverted = self.transform.inverse(input_dict)
708 # save the inverted data

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:364, in Compose.inverse(self, data)
363 for t in reversed(invertible_transforms):
--> 364 data = apply_transform(
365 t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
366 )
367 return data

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
170 _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <bound method Spacingd.inverse of <monai.transforms.spatial.dictionary.Spacingd object at 0x7f2485cff450>>

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
Cell In[6], line 18
16 except RuntimeError as e:
17 print('switching to cpu')
---> 18 out = [cpu_postprocessing(i) for i in decollate_batch(data)][0]
20 # raise RuntimeError('test runtime error')
21 else:
22 print('postprocessing on cpu')

Cell In[6], line 18, in <listcomp>(.0)
16 except RuntimeError as e:
17 print('switching to cpu')
---> 18 out = [cpu_postprocessing(i) for i in decollate_batch(data)][0]
20 # raise RuntimeError('test runtime error')
21 else:
22 print('postprocessing on cpu')

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:335, in Compose.__call__(self, input_, start, end, threading, lazy)
333 def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
334 _lazy = self.lazy if lazy is None else lazy
--> 335 result = execute_compose(
336 input_,
337 transforms=self.transforms,
338 start=start,
339 end=end,
340 map_items=self.map_items,
341 unpack_items=self.unpack_items,
342 lazy=_lazy,
343 overrides=self.overrides,
344 threading=threading,
345 log_stats=self.log_stats,
346 )
348 return result

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/compose.py:111, in execute_compose(data, transforms, map_items, unpack_items, start, end, lazy, overrides, threading, log_stats)
109 if threading:
110 _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
--> 111 data = apply_transform(
112 _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
113 )
114 data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)
115 return data

File ~/mambaforge/envs/monai/lib/python3.11/site-packages/monai/transforms/transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
169 else:
170 _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <monai.transforms.post.dictionary.Invertd object at 0x7f2485cff890>

@idinsmore1
Author

idinsmore1 commented Apr 24, 2024

I've been testing this, and I believe the error stems from the SpatialResample call made when inverting Spacingd after the error is caught, since that is the first inverse transform to run. I manually set all the transforms' tracing attributes to True using:

preprocessing.tracing = True
for transform in preprocessing.transforms:
    transform.tracing = True
cpu_postprocessing.tracing = True
for transform in cpu_postprocessing.transforms:
    transform.tracing = True
gpu_postprocessing.tracing = True
for transform in gpu_postprocessing.transforms:
    transform.tracing = True

And the error still occurs. Both data['image'].applied_operations and data['pred'].applied_operations list every transform with tracing: True, and I verified that data['image'].applied_operations == data['pred'].applied_operations. The only transform not listed there is SpatialResample, which would explain why the error persists even after manually setting this attribute.
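
Concretely, the nested flag in question can be inspected like this (attribute names taken from the traceback above; this is just an illustrative check, and Spacingd happens to be the second-to-last preprocessing transform here):

# Illustrative: inspect the tracing flag on the SpatialResample held inside Spacingd
spacingd = preprocessing.transforms[-2]                 # the Spacingd instance
print(spacingd.tracing)                                 # True after the manual reset above
print(spacingd.spacing_transform.sp_resample.tracing)   # flips to False once the OOM has been caught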

@idinsmore1
Author

Ok, so I actually got this working; I'm going to assume this is not MONAI's expected/desired behavior. When running this inference loop, preprocessing.transforms[-2] is the Spacingd transform. Before the exception is first raised, preprocessing.transforms[-2].spacing_transform.sp_resample.tracing == True; this is the tracing attribute of the SpatialResample used inside Spacingd. After the error is caught, preprocessing.transforms[-2].spacing_transform.sp_resample.tracing == False, which breaks the Invertd transform (a sketch of the likely mechanism, plus a more general reset helper, follows the updated loop below). So, if you manually reset the preprocessing tracing attributes like this:

def reset_tracing(preprocessing):
    # re-enable tracing on the Compose, its child transforms, and the
    # SpatialResample nested inside Spacingd (transforms[-2] in this pipeline)
    preprocessing.tracing = True
    for transform in preprocessing.transforms:
        transform.tracing = True
    preprocessing.transforms[-2].spacing_transform.sp_resample.tracing = True
    return preprocessing

and call this function inside the exception handler, everything works as expected:

preprocessing = Compose([
            LoadImaged(keys=['image']),
            EnsureChannelFirstd(keys=['image']),
            ThresholdIntensityd(keys=['image'], threshold=task_config['percentile_95'], above=False, cval=task_config['percentile_95']),
            ThresholdIntensityd(keys=['image'], threshold=task_config['percentile_05'], above=True, cval=task_config['percentile_05']),
            NormalizeIntensityd(keys=['image'], subtrahend=task_config['mean'], divisor=task_config['std']),
            CropForegroundd(keys=["image"], source_key="image", allow_smaller=True, select_fn=lambda x: x > task_config['crop_threshold']),
            Orientationd(keys=['image'], axcodes='RAS'),
            Spacingd(keys=['image'], pixdim=task_config['spacing'], mode='bilinear'),
            EnsureTyped(keys=['image'], track_meta=True)
        ])
postprocessing_transform = Compose([
      Activationsd(keys=['pred'], softmax=True),
      AsDiscreted(keys=['pred'], argmax=True),
      Invertd(keys=['pred'], transform=preprocessing, orig_keys='image', meta_keys='image_meta_dict', nearest_interp=True, to_tensor=True),
      SqueezeDimd(keys=['pred'], dim=0),
      ToNumpyd(keys=['pred'], dtype=np.uint8)
  ])
gpu_postprocessing = Compose([EnsureTyped(keys=['pred'], device=device), postprocessing_transform])
cpu_postprocessing = Compose([EnsureTyped(keys=['pred'], device='cpu'), postprocessing_transform])

dataset = CacheDataset(data_dict, cache_rate=1.0, transform=preprocessing, num_workers=4)
dataloader = ThreadDataLoader(dataset, batch_size=1, num_workers=0, pin_memory=True)
# set up the adaptive inferer
inferer = SlidingWindowInfererAdapt(roi_size=model_config[task]['patch_size'], sw_batch_size=batch_size, overlap=0.5)
# Run the inference loop
with autocast():
    with torch.no_grad():
        for data in dataloader:
            images = data['image'].to(device)
            # Run inference
            start_time = time.time()
            pred = inferer(inputs=images, network=model)
            inference_time = round(time.time() - start_time, 2)
            data['pred'] = pred
            # Delete the images to save gpu memory
            del images 
            # Run postprocessing. Only have one item so take index 0
            processing_start = time.time()
            # Attempt to run postprocessing on GPU, if it fails due to OOM, run it on CPU
            # If the prediction is on CPU (== -1), we just go right to CPU postprocessing
            if data['pred'].get_device() != -1:
                try:
                    out = [gpu_postprocessing(i) for i in decollate_batch(data)][0]
                except RuntimeError as e: # this is almost always an OOM error
                    print('Switching to CPU for postprocessing')
                    preprocessing = reset_tracing(preprocessing)
                    out = [cpu_postprocessing(i) for i in decollate_batch(data)][0] # With tracing reset, the CPU retry and all later items now succeed
            else:
                out = [cpu_postprocessing(i) for i in decollate_batch(data)][0]
            write_prediction(out)
            del pred
            del out
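
For completeness, a plausible explanation based on the traceback above: SpatialResample.inverse temporarily disables tracing via "with self.trace_transform(False):" before re-running the resample, so if the CUDA OOM propagates out of that block before the previous flag is restored, the internal sp_resample is left with tracing == False for all later calls. Under that assumption, a more general reset that is not tied to Spacingd sitting at index -2 could walk the pipeline and re-enable tracing on every nested traceable transform. The helper below is only a sketch of that idea, not an existing MONAI utility:

# Hypothetical generalization of reset_tracing: re-enable tracing on every nested
# TraceableTransform, including helpers held as attributes (e.g. the SpatialResample
# inside Spacingd's internal Spacing object). Sketch only, under the assumption above.
from monai.transforms import Compose
from monai.transforms.inverse import TraceableTransform

def reset_tracing_recursive(transform, enable=True, _seen=None):
    _seen = set() if _seen is None else _seen
    if id(transform) in _seen:  # Invertd holds the preprocessing Compose, so guard against revisits
        return transform
    _seen.add(id(transform))
    if isinstance(transform, TraceableTransform):
        transform.tracing = enable
    # Compose exposes its children via .transforms; other transforms may hold
    # traceable helpers as plain attributes, so inspect their __dict__ as well.
    children = transform.transforms if isinstance(transform, Compose) else vars(transform).values()
    for child in children:
        if isinstance(child, TraceableTransform):
            reset_tracing_recursive(child, enable, _seen)
    return transform

Calling reset_tracing_recursive(preprocessing) in the except block would then play the same role as the index-based reset_tracing above, without depending on where Spacingd sits in the pipeline.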
