Skip to content

Commit

Permalink
Infer the device from the model parameters
Browse files Browse the repository at this point in the history
Infer the device from the model parameters
---------

Co-authored-by: Samuel Oberhofer <[email protected]>
  • Loading branch information
soberhofer and Samuel Oberhofer committed Dec 19, 2023
1 parent a797af2 commit 00711a2
Show file tree
Hide file tree
Showing 20 changed files with 39 additions and 55 deletions.
19 changes: 10 additions & 9 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet50_Weights
from pytorch_grad_cam import (
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
Expand All @@ -18,8 +19,8 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--use-cuda', action='store_true', default=False,
help='Use NVIDIA GPU acceleration')
parser.add_argument('--device', type=str, default=None,
help='Torch device to use')
parser.add_argument(
'--image-path',
type=str,
Expand All @@ -44,9 +45,9 @@ def get_args():
parser.add_argument('--output-dir', type=str, default='output',
help='Output directory to save the images')
args = parser.parse_args()
args.use_cuda = args.use_cuda and torch.cuda.is_available()
if args.use_cuda:
print('Using GPU for acceleration')

if args.device:
print(f'Using device "{args.device}" for acceleration')
else:
print('Using CPU for computation')

Expand Down Expand Up @@ -76,7 +77,7 @@ def get_args():
"gradcamelementwise": GradCAMElementWise
}

model = models.resnet50(pretrained=True)
model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(args.device).eval()

# Choose the target layer you want to compute the visualization for.
# Usually this will be the last convolutional layer in the model.
Expand All @@ -97,7 +98,7 @@ def get_args():
rgb_img = np.float32(rgb_img) / 255
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
std=[0.229, 0.224, 0.225]).to(args.device)

# We have to specify the target we want to generate
# the Class Activation Maps for.
Expand All @@ -111,7 +112,7 @@ def get_args():
cam_algorithm = methods[args.method]
with cam_algorithm(model=model,
target_layers=target_layers,
use_cuda=args.use_cuda) as cam:
device=args.device) as cam:


# AblationCAM and ScoreCAM have batched implementations.
Expand All @@ -127,7 +128,7 @@ def get_args():
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
gb_model = GuidedBackpropReLUModel(model=model, device=args.device)
gb = gb_model(input_tensor, target_category=None)

cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
Expand Down
2 changes: 0 additions & 2 deletions pytorch_grad_cam/ablation_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ class AblationCAM(BaseCAM):
def __init__(self,
model: torch.nn.Module,
target_layers: List[torch.nn.Module],
use_cuda: bool = False,
reshape_transform: Callable = None,
ablation_layer: torch.nn.Module = AblationLayer(),
batch_size: int = 32,
ratio_channels_to_ablate: float = 1.0) -> None:

super(AblationCAM, self).__init__(model,
target_layers,
use_cuda,
reshape_transform,
uses_gradients=False)
self.batch_size = batch_size
Expand Down
4 changes: 2 additions & 2 deletions pytorch_grad_cam/ablation_cam_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def replace_layer_recursive(model, old_layer, new_layer):


class AblationCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(AblationCAM, self).__init__(model, target_layers, use_cuda,
super(AblationCAM, self).__init__(model, target_layers,
reshape_transform)

if len(target_layers) > 1:
Expand Down
9 changes: 3 additions & 6 deletions pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@ class BaseCAM:
def __init__(self,
model: torch.nn.Module,
target_layers: List[torch.nn.Module],
use_cuda: bool = False,
reshape_transform: Callable = None,
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None) -> None:
self.model = model.eval()
self.target_layers = target_layers
self.cuda = use_cuda
if self.cuda:
self.model = model.cuda()
self.device = next(self.model.parameters()).device

self.reshape_transform = reshape_transform
self.compute_input_gradient = compute_input_gradient
self.uses_gradients = uses_gradients
Expand Down Expand Up @@ -75,8 +73,7 @@ def forward(self,
targets: List[torch.nn.Module],
eigen_smooth: bool = False) -> np.ndarray:

if self.cuda:
input_tensor = input_tensor.cuda()
input_tensor = input_tensor.to(self.device)

if self.compute_input_gradient:
input_tensor = torch.autograd.Variable(input_tensor,
Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/eigen_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@


class EigenCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(EigenCAM, self).__init__(model,
target_layers,
use_cuda,
reshape_transform,
uses_gradients=False)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_grad_cam/eigen_grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@


class EigenGradCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(EigenGradCAM, self).__init__(model, target_layers, use_cuda,
super(EigenGradCAM, self).__init__(model, target_layers,
reshape_transform)

def get_cam_image(self,
Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/fullgrad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class FullGrad(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
if len(target_layers) > 0:
print(
Expand All @@ -27,7 +27,6 @@ def layer_with_2D_bias(layer):
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform,
compute_input_gradient=True)
self.bias_data = [self.get_bias_data(
Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@


class GradCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(
GradCAM,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_weights(self,
Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/grad_cam_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@


class GradCAMElementWise(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(
GradCAMElementWise,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_image(self,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_grad_cam/grad_cam_plusplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


class GradCAMPlusPlus(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda,
super(GradCAMPlusPlus, self).__init__(model, target_layers,
reshape_transform)

def get_cam_weights(self,
Expand Down
10 changes: 4 additions & 6 deletions pytorch_grad_cam/guided_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ def forward(self, input_img):


class GuidedBackpropReLUModel:
def __init__(self, model, use_cuda):
def __init__(self, model, device):
self.model = model
self.model.eval()
self.cuda = use_cuda
if self.cuda:
self.model = self.model.cuda()
self.device = next(self.model.parameters()).device

def forward(self, input_img):
return self.model(input_img)
Expand All @@ -76,8 +74,8 @@ def __call__(self, input_img, target_category=None):
torch.nn.ReLU,
GuidedBackpropReLUasModule())

if self.cuda:
input_img = input_img.cuda()

input_img = input_img.to(self.device)

input_img = input_img.requires_grad_(True)

Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/hirescam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@


class HiResCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(
HiResCAM,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_image(self,
Expand Down
2 changes: 0 additions & 2 deletions pytorch_grad_cam/layer_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ def __init__(
self,
model,
target_layers,
use_cuda=False,
reshape_transform=None):
super(
LayerCAM,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_image(self,
Expand Down
3 changes: 1 addition & 2 deletions pytorch_grad_cam/random_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@


class RandomCAM(BaseCAM):
def __init__(self, model, target_layers, use_cuda=False,
def __init__(self, model, target_layers,
reshape_transform=None):
super(
RandomCAM,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_weights(self,
Expand Down
5 changes: 1 addition & 4 deletions pytorch_grad_cam/score_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ def __init__(
self,
model,
target_layers,
use_cuda=False,
reshape_transform=None):
super(ScoreCAM, self).__init__(model,
target_layers,
use_cuda,
reshape_transform=reshape_transform,
uses_gradients=False)

Expand All @@ -26,8 +24,7 @@ def get_cam_weights(self,
upsample = torch.nn.UpsamplingBilinear2d(
size=input_tensor.shape[-2:])
activation_tensor = torch.from_numpy(activations)
if self.cuda:
activation_tensor = activation_tensor.cuda()
activation_tensor = activation_tensor.to(next(self.model.parameters()).device)

upsampled = upsample(activation_tensor)

Expand Down
6 changes: 6 additions & 0 deletions pytorch_grad_cam/utils/model_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(self, category, mask):
self.mask = torch.from_numpy(mask)
if torch.cuda.is_available():
self.mask = self.mask.cuda()
if torch.backends.mps.is_available():
self.mask = self.mask.to("mps")

def __call__(self, model_output):
return (model_output[self.category, :, :] * self.mask).sum()
Expand All @@ -86,6 +88,8 @@ def __call__(self, model_outputs):
output = torch.Tensor([0])
if torch.cuda.is_available():
output = output.cuda()
elif torch.backends.mps.is_available():
output = output.to("mps")

if len(model_outputs["boxes"]) == 0:
return output
Expand All @@ -94,6 +98,8 @@ def __call__(self, model_outputs):
box = torch.Tensor(box[None, :])
if torch.cuda.is_available():
box = box.cuda()
elif torch.backends.mps.is_available():
box = box.to("mps")

ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
index = ious.argmax()
Expand Down
2 changes: 0 additions & 2 deletions pytorch_grad_cam/xgrad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ def __init__(
self,
model,
target_layers,
use_cuda=False,
reshape_transform=None):
super(
XGradCAM,
self).__init__(
model,
target_layers,
use_cuda,
reshape_transform)

def get_cam_weights(self,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_context_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def test_memory_usage_in_loop(numpy_image, batch_size, width, height,
initial_memory = 0
for i in range(100):
with cam_method(model=model,
target_layers=target_layers,
use_cuda=False) as cam:
target_layers=target_layers) as cam:
grayscale_cam = cam(input_tensor=input_tensor,
targets=targets,
aug_smooth=aug_smooth,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_one_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def test_memory_usage_in_loop(numpy_image, cam_method):
print("input_tensor", input_tensor.shape)
targets = None
with cam_method(model=model,
target_layers=target_layers,
use_cuda=False) as cam:
target_layers=target_layers) as cam:
grayscale_cam = cam(input_tensor=input_tensor,
targets=targets)
print(grayscale_cam.shape)
3 changes: 1 addition & 2 deletions tests/test_run_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def test_all_cam_models_can_run(numpy_image, batch_size, width, height,
target_layers.append(eval(f"model.{layer}"))

cam = cam_method(model=model,
target_layers=target_layers,
use_cuda=False)
target_layers=target_layers)
cam.batch_size = 4
if target_category is None:
targets = None
Expand Down

0 comments on commit 00711a2

Please sign in to comment.