def filter_connected_components(values, counts, exclude):
    """Drop the entries whose label value equals ``exclude``.

    ``values`` and ``counts`` are parallel sequences (for example the two
    arrays returned by ``np.unique(..., return_counts=True)``).

    Args:
        values: Sequence of label values.
        counts: Sequence of occurrence counts, aligned with ``values``.
        exclude: Label value to leave out (e.g. the background label).

    Returns:
        A tuple ``(selected_indices, selected_values, selected_counts)`` of
        three lists of equal length. ``selected_indices`` holds the integer
        positions in ``values`` of the entries that were kept.
    """
    selected_indices = []
    selected_values = []
    selected_counts = []
    for index, (value, count) in enumerate(zip(values, counts)):
        if value == exclude:
            continue
        # Store the plain integer position (not a one-element list) so the
        # result can be used directly as a scalar index or comparison value.
        selected_indices.append(index)
        selected_values.append(value)
        selected_counts.append(count)
    return selected_indices, selected_values, selected_counts
def get_output_tensor(output, verbose=False):
    """Return a single tensor from a network's raw output.

    Some segmentation networks return a container of outputs instead of one
    tensor; in that case the first entry is selected.

    Args:
        output: A ``torch.Tensor``, a dict (including ``OrderedDict``) of
            outputs, or a list/tuple of outputs.
        verbose: When true, print which entry was selected.

    Returns:
        ``output`` itself if it is a tensor, otherwise the first value of the
        dict or the first element of the sequence.

    Raises:
        RuntimeError: If ``output`` has an unsupported type or is an empty
            container.
    """
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, dict):
        if not output:
            raise RuntimeError('Empty dict output, no tensor to select')
        k = next(iter(output.keys()))
        if verbose:
            print(f'Select "{k}" from dict {output.keys()}')
        return output[k]
    if isinstance(output, (list, tuple)):
        if not output:
            raise RuntimeError('Empty sequence output, no tensor to select')
        if verbose:
            print(f'Select "[0]" from list(n={len(output)})')
        return output[0]
    raise RuntimeError(f'Unknown type {type(output)}')


class SegModel(torch.nn.Module):
    """Wrap a segmentation model so it yields one summed score per channel.

    The wrapped model's output is reduced by summing over its last two
    dimensions, optionally after being masked by a region of interest.

    Args:
        model: The segmentation network to wrap.
        roi: Optional object exposing ``apply_roi(output)``; when given, its
            result replaces the output before the spatial sum.
    """

    def __init__(self, model, roi=None):
        super().__init__()
        self.model = model
        self.roi = roi

    def forward(self, x):
        """Run the model on ``x`` and return the spatially summed output.

        The third-from-last axis is treated as the channel axis (assumed to
        be the class axis of a ``(..., C, H, W)`` output - confirm against
        the wrapped model).
        """
        output = self.model(x)  # might be multiple tensors
        output = get_output_tensor(output)  # ensure only one tensor

        num_channels = output.shape[-3]
        if num_channels == 1:
            # Binary problem with a single (sigmoid-style) channel: build a
            # two-channel tensor so it can be treated like a multi-class
            # output. log_softmax needs a tensor, so the two channels must be
            # concatenated first rather than passed as a Python list.
            # NOTE(review): softmax over [-z, z] equals sigmoid(2z) for the
            # positive channel; confirm this scaling is intended.
            two_channel = torch.cat([-output, output], dim=-3)
            output = torch.log_softmax(two_channel, dim=-3)

        if self.roi is not None:
            output = self.roi.apply_roi(output)
        return torch.sum(output, dim=(-2, -1))
def get_args(argv=None):
    """Parse the command-line options of the segmentation CAM demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read the process command line.

    Returns:
        The parsed ``argparse.Namespace``. ``use_cuda`` is forced to False
        when CUDA is not available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--image-path', type=str, default='./examples/both.png',
                        help='Input image path')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument('--eigen_smooth', action='store_true',
                        help='Reduce noise by taking the first principal component '
                             'of cam_weights*activations')
    parser.add_argument('--method', type=str, default='gradcam',
                        choices=['gradcam', 'gradcam++', 'scorecam', 'xgradcam',
                                 'ablationcam', 'eigencam', 'eigengradcam'],
                        help='Can be gradcam/gradcam++/scorecam/xgradcam'
                             '/ablationcam/eigencam/eigengradcam')
    parser.add_argument('--roimode', type=int, default=0, choices=[0, 1, 2, 3],
                        help='ROI selection: 0 = all pixels, 1 = fixed pixel, '
                             '2 = user-picked pixel, 3 = class/component ROI')
    args = parser.parse_args(argv)
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args


def main():
    """Run the demo: python segcam.py --image-path <path> [--roimode N]

    Example usage of loading an image, and computing:
        1. CAM
        2. Guided Back Propagation
        3. Combining both
    """
    args = get_args()
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
    }

    model = models.segmentation.fcn_resnet50(pretrained=True)
    model.eval()
    # Choose the target layer you want to compute the visualization for.
    # Usually this will be the last convolutional layer in the model.
    # You can print the model to help choose the layer.
    target_layer = model.backbone.layer4[-1]

    bgr_img = cv2.imread(args.image_path, 1)
    if bgr_img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {args.image_path}')
    rgb_img = bgr_img[:, :, ::-1]  # BGR -> RGB
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    roi_mode = args.roimode
    if roi_mode == 0:
        # All pixels
        segmodel = SegModel(model, roi=BaseROI(rgb_img))
    elif roi_mode == 1:
        # Single pixel assigned in code
        roi = PixelROI(50, 130, rgb_img)
        segmodel = SegModel(model, roi=roi)
    elif roi_mode == 2:
        # User picks a pixel. Picking before or after the ROI is passed to
        # the model both work, since the model keeps a reference to it.
        roi = PixelROI(50, 130, rgb_img)
        segmodel = SegModel(model, roi=roi)
        roi.pickPixel()
    elif roi_mode == 3:
        # Of a specific class (GT or prediction, depending on what is passed)
        with torch.no_grad():
            pred = torch.argmax(get_output_tensor(model(input_tensor)), -3).squeeze(0)
        # NOTE(review): 12 is presumably the 'dog' index of the label set the
        # pretrained FCN was trained on - confirm.
        roi = ClassROI(rgb_img, pred, 12)
        # Alternatives: roi.largestComponent(), roi.smallestComponent(),
        # roi.pickClass()
        roi.pickComponentClass()
        segmodel = SegModel(model, roi=roi)
    else:
        raise ValueError(f'Unsupported roimode: {roi_mode}')

    cam = methods[args.method](model=segmodel,
                               target_layer=target_layer,
                               use_cuda=args.use_cuda)

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested category.
    target_category = None

    # AblationCAM and ScoreCAM have batched implementations.
    # You can override the internal batch size for faster computation.
    cam.batch_size = 32

    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category,
                        aug_smooth=args.aug_smooth,
                        eigen_smooth=args.eigen_smooth)

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)

    gb_model = GuidedBackpropReLUModel(model=segmodel, use_cuda=args.use_cuda)
    gb = gb_model(input_tensor, target_category=target_category)

    cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
    cam_gb = deprocess_image(cam_mask * gb)
    gb = deprocess_image(gb)

    # Flip to False to write the results to JPEG files instead of showing
    # them in matplotlib windows.
    show_interactive = True
    if show_interactive:
        plt.figure()
        plt.imshow(cam_image)
        plt.figure()
        plt.imshow(cam_gb)
        plt.show()
    else:
        cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
        cv2.imwrite(f'{args.method}_gb.jpg', gb)
        cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)


if __name__ == '__main__':
    main()