From 00711a20b7a8dfb6d6d1d892216997daeeee6288 Mon Sep 17 00:00:00 2001
From: Samuel Oberhofer
Date: Tue, 19 Dec 2023 11:12:13 +0100
Subject: [PATCH] Infer the device from the model parameters

Infer the device from the model parameters

---------

Co-authored-by: Samuel Oberhofer
---
 cam.py                                      | 19 ++++++++++---------
 pytorch_grad_cam/ablation_cam.py            |  2 --
 pytorch_grad_cam/ablation_cam_multilayer.py |  4 ++--
 pytorch_grad_cam/base_cam.py                |  9 +++------
 pytorch_grad_cam/eigen_cam.py               |  3 +--
 pytorch_grad_cam/eigen_grad_cam.py          |  4 ++--
 pytorch_grad_cam/fullgrad_cam.py            |  3 +--
 pytorch_grad_cam/grad_cam.py                |  3 +--
 pytorch_grad_cam/grad_cam_elementwise.py    |  3 +--
 pytorch_grad_cam/grad_cam_plusplus.py       |  4 ++--
 pytorch_grad_cam/guided_backprop.py         | 10 ++++------
 pytorch_grad_cam/hirescam.py                |  3 +--
 pytorch_grad_cam/layer_cam.py               |  2 --
 pytorch_grad_cam/random_cam.py              |  3 +--
 pytorch_grad_cam/score_cam.py               |  5 +----
 pytorch_grad_cam/utils/model_targets.py     |  6 ++++++
 pytorch_grad_cam/xgrad_cam.py               |  2 --
 tests/test_context_release.py               |  3 +--
 tests/test_one_channel.py                   |  3 +--
 tests/test_run_all_models.py                |  3 +--
 20 files changed, 39 insertions(+), 55 deletions(-)

diff --git a/cam.py b/cam.py
index 0724c960..12962490 100644
--- a/cam.py
+++ b/cam.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 from torchvision import models
+from torchvision.models import ResNet50_Weights
 from pytorch_grad_cam import (
     GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
     AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
@@ -18,8 +19,8 @@
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--use-cuda', action='store_true', default=False,
-                        help='Use NVIDIA GPU acceleration')
+    parser.add_argument('--device', type=str, default=None,
+                        help='Torch device to use')
     parser.add_argument(
         '--image-path',
         type=str,
@@ -44,9 +45,9 @@ def get_args():
     parser.add_argument('--output-dir', type=str, default='output',
                         help='Output directory to save the images')
     args = parser.parse_args()
-    args.use_cuda = args.use_cuda and torch.cuda.is_available()
-    if args.use_cuda:
-        print('Using GPU for acceleration')
+
+    if args.device:
+        print(f'Using device "{args.device}" for acceleration')
     else:
         print('Using CPU for computation')
 
@@ -76,7 +77,7 @@ def get_args():
         "gradcamelementwise": GradCAMElementWise
     }
 
-    model = models.resnet50(pretrained=True)
+    model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(args.device).eval()
 
     # Choose the target layer you want to compute the visualization for.
     # Usually this will be the last convolutional layer in the model.
@@ -97,7 +98,7 @@ def get_args():
     rgb_img = np.float32(rgb_img) / 255
     input_tensor = preprocess_image(rgb_img,
                                     mean=[0.485, 0.456, 0.406],
-                                    std=[0.229, 0.224, 0.225])
+                                    std=[0.229, 0.224, 0.225]).to(args.device)
 
     # We have to specify the target we want to generate
     # the Class Activation Maps for.
@@ -111,7 +112,7 @@ def get_args():
     cam_algorithm = methods[args.method]
     with cam_algorithm(model=model,
                        target_layers=target_layers,
-                       use_cuda=args.use_cuda) as cam:
+                       device=args.device) as cam:
 
         # AblationCAM and ScoreCAM have batched implementations.
@@ -127,7 +128,7 @@ def get_args():
         cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
         cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
 
-    gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
+    gb_model = GuidedBackpropReLUModel(model=model, device=args.device)
     gb = gb_model(input_tensor, target_category=None)
 
     cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
diff --git a/pytorch_grad_cam/ablation_cam.py b/pytorch_grad_cam/ablation_cam.py
index c45090a7..252b5b07 100644
--- a/pytorch_grad_cam/ablation_cam.py
+++ b/pytorch_grad_cam/ablation_cam.py
@@ -28,7 +28,6 @@ class AblationCAM(BaseCAM):
     def __init__(self,
                  model: torch.nn.Module,
                  target_layers: List[torch.nn.Module],
-                 use_cuda: bool = False,
                  reshape_transform: Callable = None,
                  ablation_layer: torch.nn.Module = AblationLayer(),
                  batch_size: int = 32,
@@ -36,7 +35,6 @@ def __init__(self,
 
         super(AblationCAM, self).__init__(model,
                                           target_layers,
-                                          use_cuda,
                                           reshape_transform,
                                           uses_gradients=False)
         self.batch_size = batch_size
diff --git a/pytorch_grad_cam/ablation_cam_multilayer.py b/pytorch_grad_cam/ablation_cam_multilayer.py
index 9b9dc806..721bf7a7 100644
--- a/pytorch_grad_cam/ablation_cam_multilayer.py
+++ b/pytorch_grad_cam/ablation_cam_multilayer.py
@@ -57,9 +57,9 @@ def replace_layer_recursive(model, old_layer, new_layer):
 
 
 class AblationCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
-        super(AblationCAM, self).__init__(model, target_layers, use_cuda,
+        super(AblationCAM, self).__init__(model, target_layers,
                                           reshape_transform)
 
         if len(target_layers) > 1:
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
index e9a079c6..81a38da1 100644
--- a/pytorch_grad_cam/base_cam.py
+++ b/pytorch_grad_cam/base_cam.py
@@ -12,16 +12,14 @@ class BaseCAM:
     def __init__(self,
                  model: torch.nn.Module,
                  target_layers: List[torch.nn.Module],
-                 use_cuda: bool = False,
                  reshape_transform: Callable = None,
                  compute_input_gradient: bool = False,
                  uses_gradients: bool = True,
                  tta_transforms: Optional[tta.Compose] = None) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
-        self.cuda = use_cuda
-        if self.cuda:
-            self.model = model.cuda()
+        self.device = next(self.model.parameters()).device
+
         self.reshape_transform = reshape_transform
         self.compute_input_gradient = compute_input_gradient
         self.uses_gradients = uses_gradients
@@ -75,8 +73,7 @@ def forward(self,
                 targets: List[torch.nn.Module],
                 eigen_smooth: bool = False) -> np.ndarray:
 
-        if self.cuda:
-            input_tensor = input_tensor.cuda()
+        input_tensor = input_tensor.to(self.device)
 
         if self.compute_input_gradient:
             input_tensor = torch.autograd.Variable(input_tensor,
diff --git a/pytorch_grad_cam/eigen_cam.py b/pytorch_grad_cam/eigen_cam.py
index fd6d6bc1..efe5133b 100644
--- a/pytorch_grad_cam/eigen_cam.py
+++ b/pytorch_grad_cam/eigen_cam.py
@@ -5,11 +5,10 @@
 
 
 class EigenCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         super(EigenCAM, self).__init__(model,
                                        target_layers,
-                                       use_cuda,
                                        reshape_transform,
                                        uses_gradients=False)
 
diff --git a/pytorch_grad_cam/eigen_grad_cam.py b/pytorch_grad_cam/eigen_grad_cam.py
index 3932a96d..7cd073e7 100644
--- a/pytorch_grad_cam/eigen_grad_cam.py
+++ b/pytorch_grad_cam/eigen_grad_cam.py
@@ -6,9 +6,9 @@
 
 
 class EigenGradCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
-        super(EigenGradCAM, self).__init__(model, target_layers, use_cuda,
+        super(EigenGradCAM, self).__init__(model, target_layers,
                                            reshape_transform)
 
     def get_cam_image(self,
diff --git a/pytorch_grad_cam/fullgrad_cam.py b/pytorch_grad_cam/fullgrad_cam.py
index 1a2685ef..7760f5b9 100644
--- a/pytorch_grad_cam/fullgrad_cam.py
+++ b/pytorch_grad_cam/fullgrad_cam.py
@@ -9,7 +9,7 @@
 
 
 class FullGrad(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         if len(target_layers) > 0:
             print(
@@ -27,7 +27,6 @@ def layer_with_2D_bias(layer):
         super(
             FullGrad,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform,
             compute_input_gradient=True)
         self.bias_data = [self.get_bias_data(
diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py
index 025bf45d..b03f69e1 100644
--- a/pytorch_grad_cam/grad_cam.py
+++ b/pytorch_grad_cam/grad_cam.py
@@ -3,14 +3,13 @@
 
 
 class GradCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         super(
             GradCAM,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_weights(self,
diff --git a/pytorch_grad_cam/grad_cam_elementwise.py b/pytorch_grad_cam/grad_cam_elementwise.py
index 2698d474..2b98496c 100644
--- a/pytorch_grad_cam/grad_cam_elementwise.py
+++ b/pytorch_grad_cam/grad_cam_elementwise.py
@@ -4,14 +4,13 @@
 
 
 class GradCAMElementWise(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         super(
             GradCAMElementWise,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_image(self,
diff --git a/pytorch_grad_cam/grad_cam_plusplus.py b/pytorch_grad_cam/grad_cam_plusplus.py
index 4466826b..b592c923 100644
--- a/pytorch_grad_cam/grad_cam_plusplus.py
+++ b/pytorch_grad_cam/grad_cam_plusplus.py
@@ -5,9 +5,9 @@
 
 
 class GradCAMPlusPlus(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
-        super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda,
+        super(GradCAMPlusPlus, self).__init__(model, target_layers,
                                               reshape_transform)
 
     def get_cam_weights(self,
diff --git a/pytorch_grad_cam/guided_backprop.py b/pytorch_grad_cam/guided_backprop.py
index 602fbf35..3e775cfc 100644
--- a/pytorch_grad_cam/guided_backprop.py
+++ b/pytorch_grad_cam/guided_backprop.py
@@ -44,12 +44,10 @@ def forward(self, input_img):
 
 
 class GuidedBackpropReLUModel:
-    def __init__(self, model, use_cuda):
+    def __init__(self, model, device):
         self.model = model
         self.model.eval()
-        self.cuda = use_cuda
-        if self.cuda:
-            self.model = self.model.cuda()
+        self.device = next(self.model.parameters()).device
 
     def forward(self, input_img):
         return self.model(input_img)
@@ -76,8 +74,8 @@ def __call__(self, input_img, target_category=None):
                                          torch.nn.ReLU,
                                          GuidedBackpropReLUasModule())
 
-        if self.cuda:
-            input_img = input_img.cuda()
+
+        input_img = input_img.to(self.device)
 
         input_img = input_img.requires_grad_(True)
 
diff --git a/pytorch_grad_cam/hirescam.py b/pytorch_grad_cam/hirescam.py
index 381d8d45..ea300335 100644
--- a/pytorch_grad_cam/hirescam.py
+++ b/pytorch_grad_cam/hirescam.py
@@ -4,14 +4,13 @@
 
 
 class HiResCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         super(
             HiResCAM,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_image(self,
diff --git a/pytorch_grad_cam/layer_cam.py b/pytorch_grad_cam/layer_cam.py
index 971443d7..a2d70a17 100644
--- a/pytorch_grad_cam/layer_cam.py
+++ b/pytorch_grad_cam/layer_cam.py
@@ -10,14 +10,12 @@ def __init__(
             self,
             model,
             target_layers,
-            use_cuda=False,
             reshape_transform=None):
         super(
             LayerCAM,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_image(self,
diff --git a/pytorch_grad_cam/random_cam.py b/pytorch_grad_cam/random_cam.py
index 5bb6eccd..1922441f 100644
--- a/pytorch_grad_cam/random_cam.py
+++ b/pytorch_grad_cam/random_cam.py
@@ -3,14 +3,13 @@
 
 
 class RandomCAM(BaseCAM):
-    def __init__(self, model, target_layers, use_cuda=False,
+    def __init__(self, model, target_layers,
                  reshape_transform=None):
         super(
             RandomCAM,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_weights(self,
diff --git a/pytorch_grad_cam/score_cam.py b/pytorch_grad_cam/score_cam.py
index da1f262c..275509ad 100644
--- a/pytorch_grad_cam/score_cam.py
+++ b/pytorch_grad_cam/score_cam.py
@@ -8,11 +8,9 @@ def __init__(
             self,
             model,
             target_layers,
-            use_cuda=False,
             reshape_transform=None):
         super(ScoreCAM, self).__init__(model,
                                        target_layers,
-                                       use_cuda,
                                        reshape_transform=reshape_transform,
                                        uses_gradients=False)
 
@@ -26,8 +24,7 @@ def get_cam_weights(self,
             upsample = torch.nn.UpsamplingBilinear2d(
                 size=input_tensor.shape[-2:])
             activation_tensor = torch.from_numpy(activations)
-            if self.cuda:
-                activation_tensor = activation_tensor.cuda()
+            activation_tensor = activation_tensor.to(next(self.model.parameters()).device)
 
             upsampled = upsample(activation_tensor)
 
diff --git a/pytorch_grad_cam/utils/model_targets.py b/pytorch_grad_cam/utils/model_targets.py
index 489dd198..8b8389d4 100644
--- a/pytorch_grad_cam/utils/model_targets.py
+++ b/pytorch_grad_cam/utils/model_targets.py
@@ -61,6 +61,8 @@ def __init__(self, category, mask):
         self.mask = torch.from_numpy(mask)
         if torch.cuda.is_available():
             self.mask = self.mask.cuda()
+        if torch.backends.mps.is_available():
+            self.mask = self.mask.to("mps")
 
     def __call__(self, model_output):
         return (model_output[self.category, :, :] * self.mask).sum()
@@ -86,6 +88,8 @@ def __call__(self, model_outputs):
         output = torch.Tensor([0])
         if torch.cuda.is_available():
             output = output.cuda()
+        elif torch.backends.mps.is_available():
+            output = output.to("mps")
 
         if len(model_outputs["boxes"]) == 0:
             return output
@@ -94,6 +98,8 @@ def __call__(self, model_outputs):
             box = torch.Tensor(box[None, :])
             if torch.cuda.is_available():
                 box = box.cuda()
+            elif torch.backends.mps.is_available():
+                box = box.to("mps")
 
             ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
             index = ious.argmax()
diff --git a/pytorch_grad_cam/xgrad_cam.py b/pytorch_grad_cam/xgrad_cam.py
index 81a920fe..4310e767 100644
--- a/pytorch_grad_cam/xgrad_cam.py
+++ b/pytorch_grad_cam/xgrad_cam.py
@@ -7,14 +7,12 @@ def __init__(
             self,
             model,
             target_layers,
-            use_cuda=False,
             reshape_transform=None):
         super(
             XGradCAM,
             self).__init__(
             model,
             target_layers,
-            use_cuda,
             reshape_transform)
 
     def get_cam_weights(self,
diff --git a/tests/test_context_release.py b/tests/test_context_release.py
index b40b415b..20540b94 100644
--- a/tests/test_context_release.py
+++ b/tests/test_context_release.py
@@ -57,8 +57,7 @@ def test_memory_usage_in_loop(numpy_image, batch_size, width, height,
     initial_memory = 0
     for i in range(100):
         with cam_method(model=model,
-                        target_layers=target_layers,
-                        use_cuda=False) as cam:
+                        target_layers=target_layers) as cam:
             grayscale_cam = cam(input_tensor=input_tensor,
                                 targets=targets,
                                 aug_smooth=aug_smooth,
diff --git a/tests/test_one_channel.py b/tests/test_one_channel.py
index b8d4e542..1a429c80 100644
--- a/tests/test_one_channel.py
+++ b/tests/test_one_channel.py
@@ -42,8 +42,7 @@ def test_memory_usage_in_loop(numpy_image, cam_method):
     print("input_tensor", input_tensor.shape)
     targets = None
     with cam_method(model=model,
-                    target_layers=target_layers,
-                    use_cuda=False) as cam:
+                    target_layers=target_layers) as cam:
         grayscale_cam = cam(input_tensor=input_tensor,
                             targets=targets)
         print(grayscale_cam.shape)
diff --git a/tests/test_run_all_models.py b/tests/test_run_all_models.py
index d9bb5cd7..41db4eb4 100644
--- a/tests/test_run_all_models.py
+++ b/tests/test_run_all_models.py
@@ -64,8 +64,7 @@ def test_all_cam_models_can_run(numpy_image, batch_size, width, height,
         target_layers.append(eval(f"model.{layer}"))
 
     cam = cam_method(model=model,
-                     target_layers=target_layers,
-                     use_cuda=False)
+                     target_layers=target_layers)
     cam.batch_size = 4
     if target_category is None:
         targets = None
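After this patch the CAM classes no longer take a use_cuda flag: BaseCAM reads its device from next(self.model.parameters()).device and moves input tensors there itself. The sketch below is an illustrative usage example, not part of the patch; the random input tensor, the ImageNet class index 281, and the choice of model.layer4 as target layer are placeholders for a real preprocessed image, target class, and layer.

    import torch
    from torchvision import models
    from torchvision.models import ResNet50_Weights
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    # Pick any torch device; CUDA is used here only if it is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Moving the model is the only device-related step: the CAM object
    # infers its device from the model parameters.
    model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(device).eval()
    target_layers = [model.layer4]

    input_tensor = torch.rand(1, 3, 224, 224)   # stand-in for a preprocessed image batch
    targets = [ClassifierOutputTarget(281)]     # stand-in target class

    with GradCAM(model=model, target_layers=target_layers) as cam:
        # cam() moves input_tensor to the inferred device and returns a
        # (batch, height, width) numpy array of grayscale CAMs.
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    print(grayscale_cam.shape)

Because the device is inferred from the model parameters, the same example runs unchanged whether the model sits on CPU, CUDA, or Apple's MPS backend; no device argument is passed to GradCAM here.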