From 64728aea9e529276530a1c7a032ac152d21a5038 Mon Sep 17 00:00:00 2001 From: Shaun Song Date: Tue, 25 May 2021 01:12:29 -0400 Subject: [PATCH 1/4] add segmentation network adaptor --- requirements.txt | 5 +- segcam.py | 222 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 2 deletions(-) create mode 100644 segcam.py diff --git a/requirements.txt b/requirements.txt index 9f05c345..dbc2a819 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -dataclasses==0.8 +dataclasses dicom-factory==0.0.3 numpy==1.19.5 Pillow==8.1.1 @@ -6,4 +6,5 @@ torch==1.7.1 torchvision==0.8.2 typing-extensions==3.7.4.3 ttach -tqdm \ No newline at end of file +tqdm +opencv-python diff --git a/segcam.py b/segcam.py new file mode 100644 index 00000000..c30de859 --- /dev/null +++ b/segcam.py @@ -0,0 +1,222 @@ +import argparse +import collections + +import cv2 +import numpy as np +import torch +import torch.nn +from torchvision import models +import matplotlib.pyplot as plt + +from pytorch_grad_cam import GradCAM, \ + ScoreCAM, \ + GradCAMPlusPlus, \ + AblationCAM, \ + XGradCAM, \ + EigenCAM, \ + EigenGradCAM + +from pytorch_grad_cam import GuidedBackpropReLUModel +from pytorch_grad_cam.utils.image import show_cam_on_image, \ + deprocess_image, \ + preprocess_image + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--use-cuda', action='store_true', default=False, + help='Use NVIDIA GPU acceleration') + parser.add_argument('--image-path', type=str, default='./examples/both.png', + help='Input image path') + parser.add_argument('--aug_smooth', action='store_true', + help='Apply test time augmentation to smooth the CAM') + parser.add_argument('--eigen_smooth', action='store_true', + help='Reduce noise by taking the first principle componenet' + 'of cam_weights*activations') + parser.add_argument('--method', type=str, default='gradcam', + choices=['gradcam', 'gradcam++', 'scorecam', 'xgradcam', + 'ablationcam', 'eigencam', 'eigengradcam'], + help='Can be gradcam/gradcam++/scorecam/xgradcam' + '/ablationcam/eigencam/eigengradcam') + + args = parser.parse_args() + args.use_cuda = args.use_cuda and torch.cuda.is_available() + if args.use_cuda: + print('Using GPU for acceleration') + else: + print('Using CPU for computation') + + return args + + +class BaseROI: + def __init__(self, image = None): + self.image = image + self.roi = 1 + self.fullroi = None + self.i = None + self.j = None + + def setROIij(self): + print(f'Shape of ROI:{self.roi.shape}') + self.i = np.where(self.roi == 1)[0] + self.j = np.where(self.roi == 1)[1] + print(f'Lengths of i and j index lists: {len(self.i)}, {len(self.j)}') + + def meshgrid(self): + ylist = np.linspace(0, self.image.shape[0], self.image.shape[0]) + xlist = np.linspace(0, self.image.shape[1], self.image.shape[1]) + return np.meshgrid(xlist, ylist) + +class PixelROI(BaseROI): + def __init__(self, i, j, image): + self.image = image + self.roi = torch.zeros((image.shape[-3], image.shape[-2])) + self.roi[i, j] = 1 + self.i = i + self.j = j +# +# class ClassROI(BaseROI): +# def __init__(self, model, image, cls): +# preds = model.predict(np.expand_dims(image, 0))[0] +# max_preds = preds.argmax(axis=-1) +# self.image = image +# self.roi = np.round(preds[..., cls] * (max_preds == cls)).reshape(image.shape[-3], image.shape[-2]) +# self.fullroi = self.roi +# self.setROIij() +# +# def connectedComponents(self, ignore=None): +# _, all_labels = cv2.connectedComponents(self.fullroi) +# # all_labels = measure.label(self.fullroi, background=0) +# +# +# (values, counts) = np.unique(all_labels * (all_labels != 0), return_counts=True) +# print("connectedComponents values, counts: ", values, counts) +# return all_labels, values, counts +# +# def largestComponent(self): +# all_labels, values, counts = self.connectedComponents() +# # find the largest component +# ind = np.argmax(counts[values != 0]) + 1 # +1 because indexing starts from 0 for the background +# print("argmax: ", ind) +# # define RoI +# self.roi = (all_labels == ind).astype(int) +# self.setRoIij() +# +# def smallestComponent(self): +# all_labels, values, counts = self.connectedComponents() +# ind = np.argmin(counts[values != 0]) + 1 +# print("argmin: ", ind) # +# self.roi = (all_labels == ind).astype(int) +# self.setRoIij() + + +def get_output_tensor(output, verbose=True): + if isinstance(output, torch.Tensor): + return output + elif isinstance(output, collections.OrderedDict): + k = next(iter(output.keys())) + if verbose: print(f'Select "{k}" from dict {output.keys()}') + return output[k] + elif isinstance(output, list): + if verbose: print(f'Select "[0]" from list(n={len(output)})') + return output[0] + else: + raise RuntimeError(f'Unknown type {type(output)}') + +class SegModel(torch.nn.Module): + def __init__(self, model, roi=None): + super(SegModel, self).__init__() + self.model = model + self.roi = roi + + def forward(self, x): + output = self.model(x) + output = get_output_tensor(output) + if self.roi is not None: + output = output * self.roi.roi + output = torch.sum(output, dim=(2, 3)) + return output + +if __name__ == '__main__': + """ python cam.py -image-path + Example usage of loading an image, and computing: + 1. CAM + 2. Guided Back Propagation + 3. Combining both + """ + + args = get_args() + methods = \ + {"gradcam": GradCAM, + "scorecam": ScoreCAM, + "gradcam++": GradCAMPlusPlus, + "ablationcam": AblationCAM, + "xgradcam": XGradCAM, + "eigencam": EigenCAM, + "eigengradcam": EigenGradCAM} + + model = models.segmentation.fcn_resnet50(pretrained=True) + + # Choose the target layer you want to compute the visualization for. + # Usually this will be the last convolutional layer in the model. + # Some common choices can be: + # Resnet18 and 50: model.layer4[-1] + # VGG, densenet161: model.features[-1] + # mnasnet1_0: model.layers[-1] + # You can print the model to help chose the layer + target_layer = model.backbone.layer4[-1] + + rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] + rgb_img = np.float32(rgb_img) / 255 + input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + segmodel = SegModel(model, roi=PixelROI(0, 130, rgb_img)) + + cam = methods[args.method](model=segmodel, + target_layer=target_layer, + use_cuda=args.use_cuda) + + + modeloutput = torch.argmax(get_output_tensor(model(input_tensor)), dim=1).squeeze(0) + + plt.matshow(modeloutput.cpu().numpy()) + plt.show() + + # If None, returns the map for the highest scoring category. + # Otherwise, targets the requested category. + target_category = 8 + + # AblationCAM and ScoreCAM have batched implementations. + # You can override the internal batch size for faster computation. + cam.batch_size = 32 + + grayscale_cam = cam(input_tensor=input_tensor, + target_category=target_category, + aug_smooth=args.aug_smooth, + eigen_smooth=args.eigen_smooth) + + # Here grayscale_cam has only one image in the batch + grayscale_cam = grayscale_cam[0, :] + + cam_image = show_cam_on_image(rgb_img, grayscale_cam) + + gb_model = GuidedBackpropReLUModel(model=segmodel, use_cuda=args.use_cuda) + gb = gb_model(input_tensor, target_category=target_category) + + cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam]) + cam_gb = deprocess_image(cam_mask * gb) + gb = deprocess_image(gb) + + plt.figure() + plt.imshow(cam_image) + plt.figure() + plt.imshow(gb) + plt.figure() + plt.imshow(cam_gb) + plt.show() + + # cv2.imwrite(f'{args.method}_cam.jpg', cam_image) + # cv2.imwrite(f'{args.method}_gb.jpg', gb) + # cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb) From fee8695b35d1e61711336917fd555a179873873f Mon Sep 17 00:00:00 2001 From: Shaun Song Date: Wed, 26 May 2021 18:57:44 -0400 Subject: [PATCH 2/4] Implement SegGradCAM to adapt Semantic Segmentation networks. The origin work https://github.com/kiraving/SegGradCAM --- pytorch_grad_cam/utils/roi.py | 123 +++++++++++++++++++++++++++++ requirements.txt | 1 + segcam.py | 144 +++++++++++++--------------------- 3 files changed, 179 insertions(+), 89 deletions(-) create mode 100644 pytorch_grad_cam/utils/roi.py diff --git a/pytorch_grad_cam/utils/roi.py b/pytorch_grad_cam/utils/roi.py new file mode 100644 index 00000000..1c3eee1a --- /dev/null +++ b/pytorch_grad_cam/utils/roi.py @@ -0,0 +1,123 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +from skimage import measure + +class BaseROI: + def __init__(self, image = None): + self.image = image + self.roi = 1 + self.fullroi = None + self.i = None + self.j = None + + def setROIij(self): + print(f'Shape of ROI:{self.roi.shape}') + self.i = np.where(self.roi == 1)[0] + self.j = np.where(self.roi == 1)[1] + print(f'Lengths of i and j index lists: {len(self.i)}, {len(self.j)}') + + def meshgrid(self): + ylist = np.linspace(0, self.image.shape[0], self.image.shape[0]) + xlist = np.linspace(0, self.image.shape[1], self.image.shape[1]) + return np.meshgrid(xlist, ylist) + + def apply_roi(self, output): + return self.roi * output + +def gui_get_point(image, i=None, j=None): + fig = plt.figure('Input Pick An Point') + scale = np.mean(image.shape[:2]) + if len(image.shape)==3: + pImg = plt.imshow(image) + else: + pImg = plt.matshow(image) + + pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') + ret = plt.ginput(1) + if ret == None or ret == []: + pass + else: + j, i = ret[0] + j, i= int(j), int(i) + pMarker.remove() + pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') + plt.close(fig) + return i, j + +class PixelROI(BaseROI): + def __init__(self, i, j, image): + self.image = image + self.roi = torch.zeros((image.shape[-3], image.shape[-2])) + self.roi[i, j] = 1 + self.i = i + self.j = j + + def pickPixel(self): + self.i, self.j = gui_get_point(self.image, self.i, self.j) + self.roi.zero_() + self.roi[self.i, self.j] = 1 + print(f'ROI Point: {self.i},{self.j}') + +def filter_connected_components(values, counts, exclude): + selected_indices=[] + selected_counts=[] + selected_values=[] + for i in range(len(values)): + if values[i] != exclude: + selected_indices.append([i]) + selected_values.append(values[i]) + selected_counts.append(counts[i]) + return selected_indices, selected_values, selected_counts + +class ClassROI(BaseROI): + def __init__(self, image, pred, cls, background=0): + self.image = image + self.pred = pred + self.cls = cls + self.roi = (pred == cls).reshape(image.shape[-3], image.shape[-2]) + self.background = background + print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}') + + def connectedComponents(self): + all_labels = measure.label(self.pred) + (values, counts) = np.unique(all_labels, return_counts=True) + print("connectedComponents values, counts: ", values, counts) + return all_labels, values, counts + + def largestComponent(self): + all_labels, values, counts = self.connectedComponents() + # find the largest component + selected_indices, selected_values, selected_counts = filter_connected_components(values, + counts, + self.background) + ind = selected_indices[np.argmax(selected_counts)] + print("largestComponent argmax: ", ind) + self.roi = torch.Tensor(all_labels == ind) + print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {values[ind]}') + + def smallestComponent(self): + all_labels, values, counts = self.connectedComponents() + selected_indices, selected_values, selected_counts = filter_connected_components(values, + counts, + self.background) + ind = selected_indices[np.argmin(selected_counts)] + print("smallestComponent argmin: ", ind) + self.roi = torch.Tensor(all_labels == ind) + print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {values[ind]}') + + def pickClass(self): + i, j = gui_get_point(self.pred) + self.cls = self.pred[i, j] + self.roi = (self.pred == self.cls).reshape(self.image.shape[-3], self.image.shape[-2]) + print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}') + + def pickComponentClass(self): + i, j = gui_get_point(self.pred) + all_labels, values, counts = self.connectedComponents() + ind = all_labels[i, j] + self.cls = all_labels[i, j] + self.roi = torch.Tensor(all_labels == ind).reshape(self.image.shape[-3], self.image.shape[-2]) + print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}') + + diff --git a/requirements.txt b/requirements.txt index dbc2a819..f4d4350e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ Pillow==8.1.1 torch==1.7.1 torchvision==0.8.2 typing-extensions==3.7.4.3 +scikit-image ttach tqdm opencv-python diff --git a/segcam.py b/segcam.py index c30de859..d866930f 100644 --- a/segcam.py +++ b/segcam.py @@ -16,6 +16,10 @@ EigenCAM, \ EigenGradCAM +from pytorch_grad_cam.utils.roi import BaseROI, \ + PixelROI, \ + ClassROI + from pytorch_grad_cam import GuidedBackpropReLUModel from pytorch_grad_cam.utils.image import show_cam_on_image, \ deprocess_image, \ @@ -38,7 +42,7 @@ def get_args(): 'ablationcam', 'eigencam', 'eigengradcam'], help='Can be gradcam/gradcam++/scorecam/xgradcam' '/ablationcam/eigencam/eigengradcam') - + parser.add_argument('--roimode', type=int, default='0') args = parser.parse_args() args.use_cuda = args.use_cuda and torch.cuda.is_available() if args.use_cuda: @@ -48,70 +52,8 @@ def get_args(): return args - -class BaseROI: - def __init__(self, image = None): - self.image = image - self.roi = 1 - self.fullroi = None - self.i = None - self.j = None - - def setROIij(self): - print(f'Shape of ROI:{self.roi.shape}') - self.i = np.where(self.roi == 1)[0] - self.j = np.where(self.roi == 1)[1] - print(f'Lengths of i and j index lists: {len(self.i)}, {len(self.j)}') - - def meshgrid(self): - ylist = np.linspace(0, self.image.shape[0], self.image.shape[0]) - xlist = np.linspace(0, self.image.shape[1], self.image.shape[1]) - return np.meshgrid(xlist, ylist) - -class PixelROI(BaseROI): - def __init__(self, i, j, image): - self.image = image - self.roi = torch.zeros((image.shape[-3], image.shape[-2])) - self.roi[i, j] = 1 - self.i = i - self.j = j -# -# class ClassROI(BaseROI): -# def __init__(self, model, image, cls): -# preds = model.predict(np.expand_dims(image, 0))[0] -# max_preds = preds.argmax(axis=-1) -# self.image = image -# self.roi = np.round(preds[..., cls] * (max_preds == cls)).reshape(image.shape[-3], image.shape[-2]) -# self.fullroi = self.roi -# self.setROIij() -# -# def connectedComponents(self, ignore=None): -# _, all_labels = cv2.connectedComponents(self.fullroi) -# # all_labels = measure.label(self.fullroi, background=0) -# -# -# (values, counts) = np.unique(all_labels * (all_labels != 0), return_counts=True) -# print("connectedComponents values, counts: ", values, counts) -# return all_labels, values, counts -# -# def largestComponent(self): -# all_labels, values, counts = self.connectedComponents() -# # find the largest component -# ind = np.argmax(counts[values != 0]) + 1 # +1 because indexing starts from 0 for the background -# print("argmax: ", ind) -# # define RoI -# self.roi = (all_labels == ind).astype(int) -# self.setRoIij() -# -# def smallestComponent(self): -# all_labels, values, counts = self.connectedComponents() -# ind = np.argmin(counts[values != 0]) + 1 -# print("argmin: ", ind) # -# self.roi = (all_labels == ind).astype(int) -# self.setRoIij() - - -def get_output_tensor(output, verbose=True): +# Get tensor from output of network. Some segmentation network returns more than 1 tensor. +def get_output_tensor(output, verbose=False): if isinstance(output, torch.Tensor): return output elif isinstance(output, collections.OrderedDict): @@ -131,11 +73,16 @@ def __init__(self, model, roi=None): self.roi = roi def forward(self, x): - output = self.model(x) - output = get_output_tensor(output) + output = self.model(x) # might be multiple tensors + output = get_output_tensor(output) # Ensure only one tensor + + N = output.shape[-3] + if N == 1: # if the original problem is binary using sigmoid, change to one-hot style. + output = torch.log_softmax([-output, output], dim=-3) + if self.roi is not None: - output = output * self.roi.roi - output = torch.sum(output, dim=(2, 3)) + output = self.roi.apply_roi(output) + output = torch.sum(output, dim=(-2, -1)) return output if __name__ == '__main__': @@ -157,7 +104,7 @@ def forward(self, x): "eigengradcam": EigenGradCAM} model = models.segmentation.fcn_resnet50(pretrained=True) - + model.eval() # Choose the target layer you want to compute the visualization for. # Usually this will be the last convolutional layer in the model. # Some common choices can be: @@ -172,21 +119,39 @@ def forward(self, x): input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - segmodel = SegModel(model, roi=PixelROI(0, 130, rgb_img)) + ROIMode = args.roimode + if ROIMode == 0: + ## All pixels + segmodel = SegModel(model, roi=BaseROI(rgb_img)) + elif ROIMode == 1: + ## Single code assigned roi + roi = PixelROI(50, 130, rgb_img) + segmodel = SegModel(model, roi=roi) + elif ROIMode == 2: + ## User pick a pixel + roi = PixelROI(50, 130, rgb_img) + ## Before or after pass to model, both work + # roi.pickPoint() + segmodel = SegModel(model, roi=roi) + roi.pickROI() + elif ROIMode == 3: + ## Of specific class (GT or predict, depending on what user passes) + pred = torch.argmax(get_output_tensor(model(input_tensor)), -3).squeeze(0) + roi = ClassROI(rgb_img, pred, 12) + # roi.largestComponent() + # roi.smallestComponent() + # roi.pickClass() + roi.pickComponentClass() + segmodel = SegModel(model, roi=roi) + cam = methods[args.method](model=segmodel, target_layer=target_layer, use_cuda=args.use_cuda) - - modeloutput = torch.argmax(get_output_tensor(model(input_tensor)), dim=1).squeeze(0) - - plt.matshow(modeloutput.cpu().numpy()) - plt.show() - # If None, returns the map for the highest scoring category. # Otherwise, targets the requested category. - target_category = 8 + target_category = None # AblationCAM and ScoreCAM have batched implementations. # You can override the internal batch size for faster computation. @@ -209,14 +174,15 @@ def forward(self, x): cam_gb = deprocess_image(cam_mask * gb) gb = deprocess_image(gb) - plt.figure() - plt.imshow(cam_image) - plt.figure() - plt.imshow(gb) - plt.figure() - plt.imshow(cam_gb) - plt.show() - - # cv2.imwrite(f'{args.method}_cam.jpg', cam_image) - # cv2.imwrite(f'{args.method}_gb.jpg', gb) - # cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb) + if True: + plt.figure() + plt.imshow(cam_image) + # plt.figure() + # plt.imshow(gb) + plt.figure() + plt.imshow(cam_gb) + plt.show() + else: + cv2.imwrite(f'{args.method}_cam.jpg', cam_image) + cv2.imwrite(f'{args.method}_gb.jpg', gb) + cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb) From 605eb169e0a7f9fcc07c5d2c1736b79d44131613 Mon Sep 17 00:00:00 2001 From: Shaun Song Date: Wed, 26 May 2021 19:03:58 -0400 Subject: [PATCH 3/4] Move SegModel to src/utils --- pytorch_grad_cam/utils/roi.py | 72 ++++++++++++++++++++++++++--------- segcam.py | 38 ++---------------- 2 files changed, 56 insertions(+), 54 deletions(-) diff --git a/pytorch_grad_cam/utils/roi.py b/pytorch_grad_cam/utils/roi.py index 1c3eee1a..547d170e 100644 --- a/pytorch_grad_cam/utils/roi.py +++ b/pytorch_grad_cam/utils/roi.py @@ -2,6 +2,27 @@ import torch import matplotlib.pyplot as plt from skimage import measure +import collections + +def gui_get_point(image, i=None, j=None): + fig = plt.figure('Input Pick An Point') + scale = np.mean(image.shape[:2]) + if len(image.shape)==3: + pImg = plt.imshow(image) + else: + pImg = plt.matshow(image) + + pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') + ret = plt.ginput(1) + if ret == None or ret == []: + pass + else: + j, i = ret[0] + j, i= int(j), int(i) + pMarker.remove() + pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') + plt.close(fig) + return i, j class BaseROI: def __init__(self, image = None): @@ -25,25 +46,6 @@ def meshgrid(self): def apply_roi(self, output): return self.roi * output -def gui_get_point(image, i=None, j=None): - fig = plt.figure('Input Pick An Point') - scale = np.mean(image.shape[:2]) - if len(image.shape)==3: - pImg = plt.imshow(image) - else: - pImg = plt.matshow(image) - - pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') - ret = plt.ginput(1) - if ret == None or ret == []: - pass - else: - j, i = ret[0] - j, i= int(j), int(i) - pMarker.remove() - pMarker = plt.scatter(j, i, c='r', s=scale, marker='x') - plt.close(fig) - return i, j class PixelROI(BaseROI): def __init__(self, i, j, image): @@ -120,4 +122,36 @@ def pickComponentClass(self): self.roi = torch.Tensor(all_labels == ind).reshape(self.image.shape[-3], self.image.shape[-2]) print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}') +# Get tensor from output of network. Some segmentation network returns more than 1 tensor. +def get_output_tensor(output, verbose=False): + if isinstance(output, torch.Tensor): + return output + elif isinstance(output, collections.OrderedDict): + k = next(iter(output.keys())) + if verbose: print(f'Select "{k}" from dict {output.keys()}') + return output[k] + elif isinstance(output, list): + if verbose: print(f'Select "[0]" from list(n={len(output)})') + return output[0] + else: + raise RuntimeError(f'Unknown type {type(output)}') + +class SegModel(torch.nn.Module): + def __init__(self, model, roi=None): + super(SegModel, self).__init__() + self.model = model + self.roi = roi + + def forward(self, x): + output = self.model(x) # might be multiple tensors + output = get_output_tensor(output) # Ensure only one tensor + + N = output.shape[-3] + if N == 1: # if the original problem is binary using sigmoid, change to one-hot style. + output = torch.log_softmax([-output, output], dim=-3) + + if self.roi is not None: + output = self.roi.apply_roi(output) + output = torch.sum(output, dim=(-2, -1)) + return output diff --git a/segcam.py b/segcam.py index d866930f..a901f037 100644 --- a/segcam.py +++ b/segcam.py @@ -1,6 +1,4 @@ import argparse -import collections - import cv2 import numpy as np import torch @@ -18,7 +16,9 @@ from pytorch_grad_cam.utils.roi import BaseROI, \ PixelROI, \ - ClassROI + ClassROI, \ + get_output_tensor, \ + SegModel from pytorch_grad_cam import GuidedBackpropReLUModel from pytorch_grad_cam.utils.image import show_cam_on_image, \ @@ -52,38 +52,6 @@ def get_args(): return args -# Get tensor from output of network. Some segmentation network returns more than 1 tensor. -def get_output_tensor(output, verbose=False): - if isinstance(output, torch.Tensor): - return output - elif isinstance(output, collections.OrderedDict): - k = next(iter(output.keys())) - if verbose: print(f'Select "{k}" from dict {output.keys()}') - return output[k] - elif isinstance(output, list): - if verbose: print(f'Select "[0]" from list(n={len(output)})') - return output[0] - else: - raise RuntimeError(f'Unknown type {type(output)}') - -class SegModel(torch.nn.Module): - def __init__(self, model, roi=None): - super(SegModel, self).__init__() - self.model = model - self.roi = roi - - def forward(self, x): - output = self.model(x) # might be multiple tensors - output = get_output_tensor(output) # Ensure only one tensor - - N = output.shape[-3] - if N == 1: # if the original problem is binary using sigmoid, change to one-hot style. - output = torch.log_softmax([-output, output], dim=-3) - - if self.roi is not None: - output = self.roi.apply_roi(output) - output = torch.sum(output, dim=(-2, -1)) - return output if __name__ == '__main__': """ python cam.py -image-path From 99b1664d57db5f9173c6491f6b47510f957e3100 Mon Sep 17 00:00:00 2001 From: Shaun Song Date: Fri, 28 May 2021 12:13:48 -0400 Subject: [PATCH 4/4] fixed function not found in ROImode=2 fixed incompatibility of --use-cuda --- pytorch_grad_cam/utils/roi.py | 4 ++-- segcam.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_grad_cam/utils/roi.py b/pytorch_grad_cam/utils/roi.py index 547d170e..aed8e60a 100644 --- a/pytorch_grad_cam/utils/roi.py +++ b/pytorch_grad_cam/utils/roi.py @@ -27,7 +27,7 @@ def gui_get_point(image, i=None, j=None): class BaseROI: def __init__(self, image = None): self.image = image - self.roi = 1 + self.roi = torch.Tensor([1]) self.fullroi = None self.i = None self.j = None @@ -44,7 +44,7 @@ def meshgrid(self): return np.meshgrid(xlist, ylist) def apply_roi(self, output): - return self.roi * output + return self.roi.to(output.device) * output class PixelROI(BaseROI): diff --git a/segcam.py b/segcam.py index a901f037..a2a2e916 100644 --- a/segcam.py +++ b/segcam.py @@ -101,7 +101,7 @@ def get_args(): ## Before or after pass to model, both work # roi.pickPoint() segmodel = SegModel(model, roi=roi) - roi.pickROI() + roi.pickPixel() elif ROIMode == 3: ## Of specific class (GT or predict, depending on what user passes) pred = torch.argmax(get_output_tensor(model(input_tensor)), -3).squeeze(0)