diff --git a/common_blocks/callbacks.py b/common_blocks/callbacks.py index 83057e1..37702d5 100644 --- a/common_blocks/callbacks.py +++ b/common_blocks/callbacks.py @@ -7,7 +7,7 @@ from PIL import Image import neptune from torch.autograd import Variable -from torch.optim.lr_scheduler import ExponentialLR +from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau from tempfile import TemporaryDirectory from steppy.base import Step, IdentityOperation @@ -200,6 +200,83 @@ def on_batch_end(self, *args, **kwargs): self.batch_id += 1 +class ReduceLROnPlateauScheduler(Callback): + def __init__(self, metric_name, minimize, reduce_factor, reduce_patience, min_lr): + super().__init__() + self.ctx = neptune.Context() + self.metric_name = metric_name + self.minimize = minimize + self.reduce_factor = reduce_factor + self.reduce_patience = reduce_patience + self.min_lr = min_lr + + def set_params(self, transformer, validation_datagen, *args, **kwargs): + super().set_params(transformer, validation_datagen) + self.validation_datagen = validation_datagen + self.model = transformer.model + self.optimizer = transformer.optimizer + self.loss_function = transformer.loss_function + self.lr_scheduler = ReduceLROnPlateau(optimizer=self.optimizer, + mode='min' if self.minimize else 'max', + factor=self.reduce_factor, + patience=self.reduce_patience, + min_lr=self.min_lr) + + def on_train_begin(self, *args, **kwargs): + self.epoch_id = 0 + self.batch_id = 0 + + def on_epoch_end(self, *args, **kwargs): + self.model.eval() + val_loss = self.get_validation_loss() + metric = val_loss[self.metric_name] + metric = metric.data.cpu().numpy()[0] + self.model.train() + + self.lr_scheduler.step(metrics=metric, epoch=self.epoch_id) + logger.info('epoch {0} current lr: {1}'.format(self.epoch_id + 1, + self.optimizer.state_dict()['param_groups'][0]['lr'])) + self.ctx.channel_send('Learning Rate', x=self.epoch_id, + y=self.optimizer.state_dict()['param_groups'][0]['lr']) + + self.epoch_id += 1 + + +class InitialLearningRateFinder(Callback): + def __init__(self, min_lr=1e-8, multipy_factor=1.05, add_factor=0.0): + super().__init__() + self.ctx = neptune.Context() + self.min_lr = min_lr + self.multipy_factor = multipy_factor + self.add_factor = add_factor + + def set_params(self, transformer, validation_datagen, *args, **kwargs): + super().set_params(transformer, validation_datagen) + self.validation_datagen = validation_datagen + self.model = transformer.model + self.optimizer = transformer.optimizer + self.loss_function = transformer.loss_function + + def on_train_begin(self, *args, **kwargs): + self.epoch_id = 0 + self.batch_id = 0 + + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + + def on_batch_end(self, metrics, *args, **kwargs): + for name, loss in metrics.items(): + loss = loss.data.cpu().numpy()[0] + current_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] + logger.info('Learning Rate {} Loss {})'.format(current_lr, loss)) + self.ctx.channel_send('Learning Rate', x=self.batch_id, y=current_lr) + self.ctx.channel_send('Loss', x=self.batch_id, y=loss) + + for param_group in self.optimizer.param_groups: + param_group['lr'] = current_lr * self.multipy_factor + self.add_factor + self.batch_id += 1 + + class ExperimentTiming(Callback): def __init__(self, epoch_every=None, batch_every=None): super().__init__() @@ -340,11 +417,24 @@ def on_epoch_end(self, *args, **kwargs): def _get_validation_loss(self): output, epoch_loss = self._transform() - y_pred = 
self._generate_prediction(output) + logger.info('Selecting best threshold') + + iout_best, threshold_best = 0.0, 0.5 + for threshold in np.linspace(0.5, 0.3, 21): + y_pred = self._generate_prediction(output, threshold) + iout_score = intersection_over_union_thresholds(self.y_true, y_pred) + logger.info('threshold {} IOUT {}'.format(threshold, iout_score)) + if iout_score > iout_best: + iout_best = iout_score + threshold_best = threshold + else: + break + logger.info('Selected best threshold {} IOUT {}'.format(threshold_best, iout_best)) logger.info('Calculating IOU and IOUT Scores') - iou_score = intersection_over_union(self.y_true, y_pred) + y_pred = self._generate_prediction(output, threshold_best) iout_score = intersection_over_union_thresholds(self.y_true, y_pred) + iou_score = intersection_over_union(self.y_true, y_pred) logger.info('IOU score on validation is {}'.format(iou_score)) logger.info('IOUT score on validation is {}'.format(iout_score)) @@ -407,14 +497,14 @@ def _transform(self): return outputs, average_losses - def _generate_prediction(self, outputs): + def _generate_prediction(self, outputs, threshold): data = {'callback_input': {'meta': self.meta_valid, 'meta_valid': None, }, 'unet_output': {**outputs} } with TemporaryDirectory() as cache_dirpath: - pipeline = self.validation_pipeline(cache_dirpath, self.loader_mode) + pipeline = self.validation_pipeline(cache_dirpath, self.loader_mode, threshold) output = pipeline.transform(data) y_pred = output['y_pred'] return y_pred @@ -494,7 +584,7 @@ def on_epoch_end(self, *args, **kwargs): self.epoch_id += 1 -def postprocessing_pipeline_simplified(cache_dirpath, loader_mode): +def postprocessing_pipeline_simplified(cache_dirpath, loader_mode, threshold): if loader_mode == 'resize_and_pad': size_adjustment_function = partial(crop_image, target_size=ORIGINAL_SIZE) elif loader_mode == 'resize': @@ -513,7 +603,7 @@ def postprocessing_pipeline_simplified(cache_dirpath, loader_mode): binarizer = Step(name='binarizer', transformer=make_apply_transformer( - partial(binarize, threshold=THRESHOLD), + partial(binarize, threshold=threshold), output_name='binarized_images', apply_on=['images']), input_steps=[mask_resize], diff --git a/common_blocks/loaders.py b/common_blocks/loaders.py index 7012bb2..bf821b8 100644 --- a/common_blocks/loaders.py +++ b/common_blocks/loaders.py @@ -14,7 +14,7 @@ import json from steppy.base import BaseTransformer -from .utils import from_pil, to_pil, binary_from_rle, ImgAug +from .utils import from_pil, to_pil, binary_from_rle, ImgAug, AddDepthChannels class ImageReader(BaseTransformer): @@ -337,6 +337,7 @@ def __init__(self, train_mode, loader_params, dataset_params, augmentation_param transforms.ToTensor(), transforms.Normalize(mean=self.dataset_params.MEAN, std=self.dataset_params.STD), + AddDepthChannels() ]) self.mask_transform = transforms.Compose([transforms.Lambda(to_array), transforms.Lambda(to_tensor), @@ -364,6 +365,7 @@ def __init__(self, loader_params, dataset_params, augmentation_params): transforms.ToTensor(), transforms.Normalize(mean=self.dataset_params.MEAN, std=self.dataset_params.STD), + AddDepthChannels() ]) self.mask_transform = transforms.Compose([transforms.Lambda(to_array), transforms.Lambda(to_tensor), diff --git a/common_blocks/models.py b/common_blocks/models.py index 77ba477..2b7fe24 100644 --- a/common_blocks/models.py +++ b/common_blocks/models.py @@ -1,45 +1,68 @@ +from functools import partial + import numpy as np import torch import torch.optim as optim from torch.autograd import 
Variable import torch.nn as nn -from functools import partial +from torch.nn import functional as F from toolkit.pytorch_transformers.models import Model from .utils import sigmoid, softmax, get_list_of_image_predictions, pytorch_where from . import callbacks as cbk -from .unet_models import UNetResNet, SaltUNet, SaltLinkNet +from .unet_models import UNetResNet from .lovasz_losses import lovasz_hinge -PRETRAINED_NETWORKS = {'ResNet34': {'model': UNetResNet, - 'model_config': {'encoder_depth': 34, - 'num_filters': 32, 'dropout_2d': 0.0, - 'pretrained': True, 'is_deconv': True, +PRETRAINED_NETWORKS = {'ResNet18': {'model': UNetResNet, + 'model_config': {'encoder_depth': 18, 'use_hypercolumn': False, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNet34': {'model': UNetResNet, + 'model_config': {'encoder_depth': 34, 'use_hypercolumn': False, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNet50': {'model': UNetResNet, + 'model_config': {'encoder_depth': 50, 'use_hypercolumn': False, + 'dropout_2d': 0.0, 'pretrained': True, }, 'init_weights': False}, 'ResNet101': {'model': UNetResNet, - 'model_config': {'encoder_depth': 101, - 'num_filters': 32, 'dropout_2d': 0.0, - 'pretrained': True, 'is_deconv': True, + 'model_config': {'encoder_depth': 101, 'use_hypercolumn': False, + 'dropout_2d': 0.0, 'pretrained': True, }, 'init_weights': False}, 'ResNet152': {'model': UNetResNet, - 'model_config': {'encoder_depth': 152, - 'num_filters': 32, 'dropout_2d': 0.0, - 'pretrained': True, 'is_deconv': True, + 'model_config': {'encoder_depth': 152, 'use_hypercolumn': False, + 'dropout_2d': 0.0, 'pretrained': True, }, 'init_weights': False}, - 'SaltLinkNet': {'model': SaltLinkNet, - 'model_config': {'dropout_2d': 0.5, - 'pretrained': True, 'is_deconv': True, - }, - 'init_weights': False}, - - 'SaltUNet': {'model': SaltUNet, - 'model_config': {'dropout_2d': 0.5, - 'pretrained': True, 'is_deconv': True, - }, - 'init_weights': False}, + 'ResNetHyper18': {'model': UNetResNet, + 'model_config': {'encoder_depth': 18, 'use_hypercolumn': True, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNetHyper34': {'model': UNetResNet, + 'model_config': {'encoder_depth': 34, 'use_hypercolumn': True, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNetHyper50': {'model': UNetResNet, + 'model_config': {'encoder_depth': 50, 'use_hypercolumn': True, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNetHyper101': {'model': UNetResNet, + 'model_config': {'encoder_depth': 101, 'use_hypercolumn': True, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, + 'ResNetHyper152': {'model': UNetResNet, + 'model_config': {'encoder_depth': 152, 'use_hypercolumn': True, + 'dropout_2d': 0.0, 'pretrained': True, + }, + 'init_weights': False}, } @@ -108,6 +131,7 @@ def _fit_loop(self, data): partial_batch_losses[name] = loss_function(output, target) * weight batch_loss = sum(partial_batch_losses.values()) partial_batch_losses['sum'] = batch_loss + batch_loss.backward() self.optimizer.step() @@ -140,6 +164,7 @@ def _transform(self, datagen, validation_datagen=None, **kwargs): else: X = Variable(X, volatile=True) outputs_batch = self.model(X) + if len(self.output_names) == 1: outputs.setdefault(self.output_names[0], []).append(outputs_batch.data.cpu().numpy()) else: @@ -162,22 +187,9 @@ def set_model(self): def set_loss(self): if self.activation_func == 'softmax': - loss_function = 
partial(mixed_dice_cross_entropy_loss, - dice_loss=multiclass_dice_loss, - cross_entropy_loss=nn.CrossEntropyLoss(), - dice_activation='softmax', - dice_weight=self.architecture_config['model_params']['dice_weight'], - cross_entropy_weight=self.architecture_config['model_params']['bce_weight'] - ) + raise NotImplementedError('No softmax loss defined') elif self.activation_func == 'sigmoid': loss_function = lovasz_loss - # loss_function = partial(mixed_dice_bce_loss, - # dice_loss=multiclass_dice_loss, - # bce_loss=nn.BCEWithLogitsLoss(), - # dice_activation='sigmoid', - # dice_weight=self.architecture_config['model_params']['dice_weight'], - # bce_weight=self.architecture_config['model_params']['bce_weight'] - # ) else: raise Exception('Only softmax and sigmoid activations are allowed') self.loss_function = [('mask', loss_function, 1.0)] @@ -197,6 +209,35 @@ def load(self, filepath): return self +class FocalWithLogitsLoss(nn.Module): + def __init__(self, alpha=1.0, gamma=1.0): + super().__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, input, target): + if not (target.size() == input.size()): + raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size())) + + max_val = (-input).clamp(min=0) + logpt = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log() + pt = torch.exp(-logpt) + at = self.alpha * target + (1 - target) + loss = at * ((1 - pt).pow(self.gamma)) * logpt + return loss + + +class DiceLoss(nn.Module): + def __init__(self, smooth=0, eps=1e-7): + super().__init__() + self.smooth = smooth + self.eps = eps + + def forward(self, output, target): + return 1 - (2 * torch.sum(output * target) + self.smooth) / ( + torch.sum(output) + torch.sum(target) + self.smooth + self.eps) + + def weight_regularization(model, regularize, weight_decay_conv2d): if regularize: parameter_list = [ @@ -211,7 +252,7 @@ def weight_regularization(model, regularize, weight_decay_conv2d): def callbacks_unet(callbacks_config): experiment_timing = cbk.ExperimentTiming(**callbacks_config['experiment_timing']) model_checkpoints = cbk.ModelCheckpoint(**callbacks_config['model_checkpoint']) - lr_scheduler = cbk.ExponentialLRScheduler(**callbacks_config['lr_scheduler']) + lr_scheduler = cbk.ReduceLROnPlateauScheduler(**callbacks_config['reduce_lr_on_plateau_scheduler']) training_monitor = cbk.TrainingMonitor(**callbacks_config['training_monitor']) validation_monitor = cbk.ValidationMonitor(**callbacks_config['validation_monitor']) neptune_monitor = cbk.NeptuneMonitor(**callbacks_config['neptune_monitor']) @@ -222,15 +263,10 @@ def callbacks_unet(callbacks_config): model_checkpoints, lr_scheduler, neptune_monitor, early_stopping]) -class DiceLoss(nn.Module): - def __init__(self, smooth=0, eps=1e-7): - super(DiceLoss, self).__init__() - self.smooth = smooth - self.eps = eps - - def forward(self, output, target): - return 1 - (2 * torch.sum(output * target) + self.smooth) / ( - torch.sum(output) + torch.sum(target) + self.smooth + self.eps) +def weighted_lovash_focal_loss(output, target): + focal = weighted_focal_loss(output, target) + lovasz = lovasz_hinge(output, target) + return 0.25 * focal + lovasz def lovasz_loss(output, target): @@ -238,66 +274,72 @@ def lovasz_loss(output, target): return lovasz_hinge(output, target) -def mixed_dice_bce_loss(output, target, dice_weight=0.2, dice_loss=None, - bce_weight=0.9, bce_loss=None, - smooth=0, dice_activation='sigmoid'): - num_classes = output.size(1) - target = 
target[:, :num_classes, :, :].long() - if bce_loss is None: - bce_loss = nn.BCEWithLogitsLoss() - if dice_loss is None: - dice_loss = multiclass_dice_loss - return dice_weight * dice_loss(output, target, smooth, dice_activation) + bce_weight * bce_loss(output, target) - - -def mixed_dice_cross_entropy_loss(output, target, dice_weight=0.5, dice_loss=None, - cross_entropy_weight=0.5, cross_entropy_loss=None, smooth=0, - dice_activation='softmax'): - num_classes_without_background = output.size(1) - 1 - dice_output = output[:, 1:, :, :] - dice_target = target[:, :num_classes_without_background, :, :].long() - cross_entropy_target = torch.zeros_like(target[:, 0, :, :]).long() - for class_nr in range(num_classes_without_background): - cross_entropy_target = where(target[:, class_nr, :, :], class_nr + 1, cross_entropy_target) - if cross_entropy_loss is None: - cross_entropy_loss = nn.CrossEntropyLoss() - if dice_loss is None: - dice_loss = multiclass_dice_loss - return dice_weight * dice_loss(dice_output, dice_target, smooth, - dice_activation) + cross_entropy_weight * cross_entropy_loss(output, - cross_entropy_target) - - -def multiclass_dice_loss(output, target, smooth=0, activation='softmax'): - """Calculate Dice Loss for multiple class output. - - Args: - output (torch.Tensor): Model output of shape (N x C x H x W). - target (torch.Tensor): Target of shape (N x H x W). - smooth (float, optional): Smoothing factor. Defaults to 0. - activation (string, optional): Name of the activation function, softmax or sigmoid. Defaults to 'softmax'. - - Returns: - torch.Tensor: Loss value. - - """ - if activation == 'softmax': - activation_nn = torch.nn.Softmax2d() - elif activation == 'sigmoid': - activation_nn = torch.nn.Sigmoid() +def weighted_focal_loss(output, target, + alpha=1.0, gamma=5.0, + max_weight=100.0, + focus_threshold=0.1, + use_size_weight=True, + use_border_weight=True, border_size=24, border_weight=2.0 + ): + output = focus_output(output, focus_threshold=focus_threshold) + loss_per_pixel = FocalWithLogitsLoss(alpha=alpha, gamma=gamma)(output, target) + weights = get_weights(target, + max_weight=max_weight, + use_size_weight=use_size_weight, + use_border_weight=use_border_weight, border_size=border_size, border_weight=border_weight) + loss = torch.mean(loss_per_pixel * weights) + return loss + + +def focus_output(output, focus_threshold): + if torch.cuda.is_available(): + output_numpy = F.sigmoid(output).data.cpu().numpy() + else: + output_numpy = F.sigmoid(output).data.numpy() + focus_weights = np.where(output_numpy < focus_threshold, 0.0, 1.0) + focus_weights = Variable(torch.Tensor(focus_weights), requires_grad=False) + if torch.cuda.is_available(): + focus_weights = focus_weights.cuda() + return torch.mul(focus_weights, output) + + +def get_weights(target, max_weight=5.0, + use_size_weight=True, + use_border_weight=True, border_size=10, border_weight=2.0): + if torch.cuda.is_available(): + target_numpy = target.data.cpu().numpy() else: - raise NotImplementedError('only sigmoid and softmax are implemented') + target_numpy = target.data.numpy() + + if use_size_weight: + size_weights = _size_weights(target_numpy) + else: + size_weights = np.ones_like(target_numpy) + + if use_border_weight: + border_weights = _border_weights(target_numpy, border_size=border_size, border_weight=border_weight) + else: + border_weights = np.ones_like(target_numpy) + + weights = border_weights * size_weights + weights = np.where(weights > max_weight, max_weight, weights) + weights = 
Variable(torch.Tensor(weights), requires_grad=False) + + if torch.cuda.is_available(): + weights = weights.cuda() + return weights + - loss = 0 - dice = DiceLoss(smooth=smooth) - output = activation_nn(output) - num_classes = output.size(1) - target.data = target.data.float() - for class_nr in range(num_classes): - loss += dice(output[:, class_nr, :, :], target[:, class_nr, :, :]) - return loss / num_classes +def _size_weights(target): + target_ = target[:, 1, :, :] + size_per_image = np.mean(target_, axis=(1, 2)) + size_per_image = np.where(size_per_image == 0.0, 1.0, size_per_image) + size_weights_per_image = 1.0 / size_per_image.reshape(-1, 1, 1, 1) + size_weights = np.where(target, np.multiply(target, size_weights_per_image), 1.0) + return size_weights -def where(cond, x_1, x_2): - cond = cond.long() - return (cond * x_1) + ((1 - cond) * x_2) +def _border_weights(target, border_size=10, border_weight=2.0): + border_mask = border_weight * np.ones_like(target) + border_mask[:, :, border_size:-border_size, border_size:-border_size] = 1.0 + return border_mask diff --git a/common_blocks/unet_models.py b/common_blocks/unet_models.py index c4b962e..13d8c16 100644 --- a/common_blocks/unet_models.py +++ b/common_blocks/unet_models.py @@ -21,7 +21,8 @@ class ConvBnRelu(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), + self.conv = nn.Sequential(nn.ReplicationPad2d(padding=1), + nn.Conv2d(in_channels, out_channels, 3, padding=0), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) @@ -35,44 +36,59 @@ def forward(self, x): return x -class DecoderBlockV1(nn.Module): +class DecoderBlock(nn.Module): def __init__(self, in_channels, middle_channels, out_channels): - super().__init__() + super(DecoderBlock, self).__init__() + self.conv1 = ConvBnRelu(in_channels, middle_channels) + self.conv2 = ConvBnRelu(middle_channels, out_channels) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') + self.relu = nn.ReLU(inplace=True) + self.channel_se = ChannelSELayer(out_channels, reduction=16) + self.spatial_se = SpatialSELayer(out_channels) - self.block = nn.Sequential( - ConvBnRelu(in_channels, middle_channels), - nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) - ) + def forward(self, x, e=None): + x = self.upsample(x) + if e is not None: + x = torch.cat([x, e], 1) + x = self.conv1(x) + x = self.conv2(x) - def forward(self, x): - return self.block(x) + channel_se = self.channel_se(x) + spatial_se = self.spatial_se(x) + x = self.relu(channel_se + spatial_se) + return x -class DecoderBlockV2(nn.Module): - def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True): - super(DecoderBlockV2, self).__init__() - self.is_deconv = is_deconv - self.deconv = nn.Sequential( - ConvBnRelu(in_channels, middle_channels), - nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(out_channels), +class ChannelSELayer(nn.Module): + def __init__(self, channel, reduction=16): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() ) - self.upsample = nn.Sequential( - ConvBnRelu(in_channels, out_channels), - nn.Upsample(scale_factor=2, mode='bilinear'), - ) + def forward(self, x): 
+ b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SpatialSELayer(nn.Module): + def __init__(self, channels): + super().__init__() + self.fc = nn.Conv2d(channels, 1, kernel_size=1) + self.sigmoid = nn.Sigmoid() def forward(self, x): - if self.is_deconv: - x = self.deconv(x) - else: - x = self.upsample(x) - return x + module_input = x + x = self.fc(x) + x = self.sigmoid(x) + return module_input * x class UNetResNet(nn.Module): @@ -95,18 +111,23 @@ class UNetResNet(nn.Module): False: bilinear interpolation is used in decoder. True: deconvolution is used in decoder. Defaults to False. - """ - def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2, - pretrained=False, is_deconv=False): + def __init__(self, encoder_depth, num_classes, dropout_2d=0.2, pretrained=False, use_hypercolumn=False): super().__init__() self.num_classes = num_classes self.dropout_2d = dropout_2d + self.use_hypercolumn = use_hypercolumn - if encoder_depth == 34: + if encoder_depth == 18: + self.encoder = torchvision.models.resnet18(pretrained=pretrained) + bottom_channel_nr = 512 + elif encoder_depth == 34: self.encoder = torchvision.models.resnet34(pretrained=pretrained) bottom_channel_nr = 512 + elif encoder_depth == 50: + self.encoder = torchvision.models.resnet50(pretrained=pretrained) + bottom_channel_nr = 2048 elif encoder_depth == 101: self.encoder = torchvision.models.resnet101(pretrained=pretrained) bottom_channel_nr = 2048 @@ -114,120 +135,73 @@ def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2, self.encoder = torchvision.models.resnet152(pretrained=pretrained) bottom_channel_nr = 2048 else: - raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented') + raise NotImplementedError('only 18, 34, 50, 101, 152 version of Resnet are implemented') self.pool = nn.MaxPool2d(2, 2) self.relu = nn.ReLU(inplace=True) - self.input_adjust = nn.Sequential(self.encoder.conv1, - self.encoder.bn1, - self.encoder.relu) - - self.conv1 = self.encoder.layer1 - self.conv2 = self.encoder.layer2 - self.conv3 = self.encoder.layer3 - self.conv4 = self.encoder.layer4 - - self.dec4 = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv) - self.dec3 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, - is_deconv) - self.dec2 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, - is_deconv) - self.dec1 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, - is_deconv) - self.final = nn.Conv2d(num_filters * 2 * 2, num_classes, kernel_size=1) - - def forward(self, x): - input_adjust = self.input_adjust(x) - conv1 = self.conv1(input_adjust) - conv2 = self.conv2(conv1) - conv3 = self.conv3(conv2) - center = self.conv4(conv3) - dec4 = self.dec4(center) - dec3 = self.dec3(torch.cat([dec4, conv3], 1)) - dec2 = self.dec2(torch.cat([dec3, conv2], 1)) - dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d) - return self.final(dec1) - - -class SaltUNet(nn.Module): - def __init__(self, num_classes, dropout_2d=0.2, pretrained=False, is_deconv=False): - super().__init__() - self.num_classes = num_classes - self.dropout_2d = dropout_2d - - self.encoder = torchvision.models.resnet34(pretrained=pretrained) - - self.relu = nn.ReLU(inplace=True) - - self.input_adjust = nn.Sequential(self.encoder.conv1, - self.encoder.bn1, - 
self.encoder.relu) - - self.conv1 = list(self.encoder.layer1.children())[1] - self.conv2 = list(self.encoder.layer1.children())[2] - self.conv3 = list(self.encoder.layer2.children())[0] - - self.conv4 = list(self.encoder.layer2.children())[1] - - self.dec3 = DecoderBlockV2(256, 512, 256, is_deconv) - self.dec2 = ConvBnRelu(256 + 64, 256) - self.dec1 = DecoderBlockV2(256 + 64, (256 + 64) * 2, 256, is_deconv) - - self.final = nn.Conv2d(256, num_classes, kernel_size=1) - - def forward(self, x): - input_adjust = self.input_adjust(x) - conv1 = self.conv1(input_adjust) - conv2 = self.conv2(conv1) - conv3 = self.conv3(conv2) - center = self.conv4(conv3) - dec3 = self.dec3(torch.cat([center, conv3], 1)) - dec2 = self.dec2(torch.cat([dec3, conv2], 1)) - dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d) - return self.final(dec1) - - -class SaltLinkNet(nn.Module): - def __init__(self, num_classes, dropout_2d=0.2, pretrained=False, is_deconv=False): - super().__init__() - self.num_classes = num_classes - self.dropout_2d = dropout_2d - - self.encoder = torchvision.models.resnet34(pretrained=pretrained) - - self.relu = nn.ReLU(inplace=True) - - self.input_adjust = nn.Sequential(self.encoder.conv1, - self.encoder.bn1, - self.encoder.relu) - - self.conv1_1 = list(self.encoder.layer1.children())[1] - self.conv1_2 = list(self.encoder.layer1.children())[2] - - self.conv2_0 = list(self.encoder.layer2.children())[0] - self.conv2_1 = list(self.encoder.layer2.children())[1] - self.conv2_2 = list(self.encoder.layer2.children())[2] - self.conv2_3 = list(self.encoder.layer2.children())[3] - - self.dec2 = DecoderBlockV2(128, 256, 256, is_deconv=is_deconv) - self.dec1 = DecoderBlockV2(256 + 64, 512, 256, is_deconv=is_deconv) - self.final = nn.Conv2d(256, num_classes, kernel_size=1) + self.conv1 = nn.Sequential(self.encoder.conv1, + self.encoder.bn1, + self.encoder.relu) + + self.encoder2 = self.encoder.layer1 + self.encoder3 = self.encoder.layer2 + self.encoder4 = self.encoder.layer3 + self.encoder5 = self.encoder.layer4 + + self.center = nn.Sequential(ConvBnRelu(bottom_channel_nr, bottom_channel_nr), + ConvBnRelu(bottom_channel_nr, bottom_channel_nr // 2), + nn.AvgPool2d(kernel_size=2, stride=2) + ) + + self.dec5 = DecoderBlock(bottom_channel_nr + bottom_channel_nr // 2, + bottom_channel_nr, + bottom_channel_nr // 8) + + self.dec4 = DecoderBlock(bottom_channel_nr // 2 + bottom_channel_nr // 8, + bottom_channel_nr // 2, + bottom_channel_nr // 8) + self.dec3 = DecoderBlock(bottom_channel_nr // 4 + bottom_channel_nr // 8, + bottom_channel_nr // 4, + bottom_channel_nr // 8) + self.dec2 = DecoderBlock(bottom_channel_nr // 8 + bottom_channel_nr // 8, + bottom_channel_nr // 8, + bottom_channel_nr // 8) + self.dec1 = DecoderBlock(bottom_channel_nr // 8, + bottom_channel_nr // 16, + bottom_channel_nr // 8) + + if self.use_hypercolumn: + self.final = nn.Sequential(ConvBnRelu(5 * bottom_channel_nr // 8, bottom_channel_nr // 8), + nn.Conv2d(bottom_channel_nr // 8, num_classes, kernel_size=1, padding=0)) + else: + self.final = nn.Sequential(ConvBnRelu(bottom_channel_nr // 8, bottom_channel_nr // 8), + nn.Conv2d(bottom_channel_nr // 8, num_classes, kernel_size=1, padding=0)) def forward(self, x): - input_adjust = self.input_adjust(x) - conv1_1 = self.conv1_1(input_adjust) - conv1_2 = self.conv1_2(conv1_1) - conv2_0 = self.conv2_0(conv1_2) - conv2_1 = self.conv2_1(conv2_0) - conv2_2 = self.conv2_2(conv2_1) - conv2_3 = self.conv2_3(conv2_2) - - conv1_sum = conv1_1 + conv1_2 - conv2_sum = conv2_0 + 
conv2_1 + conv2_2 + conv2_3 - - dec2 = self.dec2(conv2_sum) - dec1 = self.dec1(torch.cat([dec2, conv1_sum], 1)) - - return self.final(F.dropout2d(dec1, p=self.dropout_2d)) + conv1 = self.conv1(x) + encoder2 = self.encoder2(conv1) + encoder3 = self.encoder3(encoder2) + encoder4 = self.encoder4(encoder3) + encoder5 = self.encoder5(encoder4) + + center = self.center(encoder5) + + dec5 = self.dec5(center, encoder5) + dec4 = self.dec4(dec5, encoder4) + dec3 = self.dec3(dec4, encoder3) + dec2 = self.dec2(dec3, encoder2) + dec1 = self.dec1(dec2) + + if self.use_hypercolumn: + hypercolumn = torch.cat([dec1, + F.upsample(dec2, scale_factor=2, mode='bilinear'), + F.upsample(dec3, scale_factor=4, mode='bilinear'), + F.upsample(dec4, scale_factor=8, mode='bilinear'), + F.upsample(dec5, scale_factor=16, mode='bilinear'), + ], 1) + drop = F.dropout2d(hypercolumn, p=self.dropout_2d) + else: + drop = F.dropout2d(dec1, p=self.dropout_2d) + return self.final(drop) diff --git a/common_blocks/utils.py b/common_blocks/utils.py index e3b2c8b..407aa49 100644 --- a/common_blocks/utils.py +++ b/common_blocks/utils.py @@ -19,6 +19,7 @@ from sklearn.model_selection import BaseCrossValidator from steppy.base import BaseTransformer, Step from steppy.utils import get_logger +from skimage.transform import resize import yaml from imgaug import augmenters as iaa import imgaug as ia @@ -472,3 +473,69 @@ def _cached_fit_transform(self, step_inputs): def pytorch_where(cond, x_1, x_2): cond = cond.float() return (cond * x_1) + ((1 - cond) * x_2) + + +class AddDepthChannels: + def __call__(self, tensor): + _, h, w = tensor.size() + for row, const in enumerate(np.linspace(0, 1, h)): + tensor[1, row, :] = const + tensor[2] = tensor[0] * tensor[1] + return tensor + + def __repr__(self): + return self.__class__.__name__ + + +def load_image(filepath, is_mask=False): + if is_mask: + img = (np.array(Image.open(filepath)) > 0).astype(np.uint8) + else: + img = np.array(Image.open(filepath)).astype(np.uint8) + return img + + +def save_image(img, filepath): + img = Image.fromarray((img)) + img.save(filepath) + + +def resize_image(image, target_shape, is_mask=False): + if is_mask: + image = (resize(image, target_shape, preserve_range=True) > 0).astype(int) + else: + image = resize(image, target_shape) + return image + + +def get_cut_coordinates(mask, step=4, min_img_crop=20, min_size=50, max_size=300): + h, w = mask.shape + ts = [] + rots = [1, 2, 3, 0] + for rot in rots: + mask = np.rot90(mask) + for t in range(min_img_crop, h, step): + crop = mask[:t, :t] + size = crop.mean() * h * w + if min_size < size <= max_size: + break + ts.append((t, rot)) + try: + ts = [(t, r) for t, r in ts if t < 99] + best_t, best_rot = sorted(ts, key=lambda x: x[0], reverse=True)[0] + except IndexError: + return (0, w), (0, h), False + if best_t < min_img_crop: + return (0, w), (0, h), False + + if best_rot == 0: + x1, x2, y1, y2 = 0, best_t, 0, best_t + elif best_rot == 1: + x1, x2, y1, y2 = 0, best_t, h - best_t, h + elif best_rot == 2: + x1, x2, y1, y2 = w - best_t, w, h - best_t, h + elif best_rot == 3: + x1, x2, y1, y2 = w - best_t, w, 0, best_t + else: + raise ValueError + return (x1, x2), (y1, y2), True diff --git a/main.py b/main.py index c6040d6..e64fffc 100644 --- a/main.py +++ b/main.py @@ -30,7 +30,7 @@ EXPERIMENT_DIR = '/output/experiment' CLONE_EXPERIMENT_DIR_FROM = '' # When running eval in the cloud specify this as for example /input/SAL-14/output/experiment -OVERWRITE_EXPERIMENT_DIR = True +OVERWRITE_EXPERIMENT_DIR = False DEV_MODE = False 
if OVERWRITE_EXPERIMENT_DIR and os.path.isdir(EXPERIMENT_DIR): @@ -198,8 +198,6 @@ 'nr_outputs': PARAMS.nr_unet_outputs, 'encoder': PARAMS.encoder, 'activation': PARAMS.unet_activation, - 'dice_weight': PARAMS.dice_weight, - 'bce_weight': PARAMS.bce_weight, }, 'optimizer_params': {'lr': PARAMS.lr, }, @@ -219,8 +217,13 @@ 'epoch_every': 1, 'metric_name': PARAMS.validation_metric_name, 'minimize': PARAMS.minimize_validation_metric}, - 'lr_scheduler': {'gamma': PARAMS.gamma, - 'epoch_every': 1}, + 'exponential_lr_scheduler': {'gamma': PARAMS.gamma, + 'epoch_every': 1}, + 'reduce_lr_on_plateau_scheduler': {'metric_name': PARAMS.validation_metric_name, + 'minimize': PARAMS.minimize_validation_metric, + 'reduce_factor': PARAMS.reduce_factor, + 'reduce_patience': PARAMS.reduce_patience, + 'min_lr': PARAMS.min_lr}, 'training_monitor': {'batch_every': 0, 'epoch_every': 1}, 'experiment_timing': {'batch_every': 0, @@ -768,5 +771,4 @@ def save_predictions(train_ids, train_predictions, meta_test, out_of_fold_test_p if __name__ == '__main__': prepare_metadata() - train_evaluate_predict_cv() - + train_evaluate_predict_cv() \ No newline at end of file diff --git a/neptune.yaml b/neptune.yaml index 5f63d39..aa377fa 100644 --- a/neptune.yaml +++ b/neptune.yaml @@ -63,8 +63,6 @@ parameters: repeat_blocks: 4 # Loss - dice_weight: 0.0 - bce_weight: 1.0 # Training schedule epochs_nr: 10000 @@ -72,11 +70,19 @@ parameters: batch_size_inference: 64 lr: 0.0001 momentum: 0.9 - gamma: 0.95 patience: 20 validation_metric_name: 'iout' minimize_validation_metric: 0 + # Exponential LR scheduler + gamma: 0.95 + + # Reduce LR on plateau + reduce_factor: 0.1 + reduce_patience: 10 + min_lr: 1e-7 + + # Regularization use_batch_norm: 1 l2_reg_conv: 0.0001 diff --git a/prediction_average.ipynb b/prediction_average.ipynb new file mode 100644 index 0000000..7d81b3f --- /dev/null +++ b/prediction_average.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "from sklearn.externals import joblib\n", + "from tqdm import tqdm_notebook as tqdm\n", + "\n", + "from common_blocks.utils import run_length_encoding\n", + "from common_blocks.metrics import compute_eval_metric\n", + "\n", + "\n", + "METADATA_FILEPATH = 'YOUR/metadata.csv'\n", + "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = 'YOUR/out_of_fold_train_predictions.pkl'\n", + "\n", + "METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'\n", + "MODEL_DIRPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments'\n", + "EXPERIMENTS = ['sal_1036_cv_829_lb_837',\n", + " 'sal_1038_cv_829_lb_834',\n", + "# 'sal_1056_cv_817_lb_825',\n", + " 'sal_1070_cv_831_lb_837',\n", + "# 'sal_1071_cv_815_lb_822',\n", + " 'sal_1078_cv_827_lb_836',\n", + "# 'sal_1085_cv_819_lb_825',\n", + "# 'sal_984_cv_819_lb_824',\n", + " 'sal_986_cv_821_lb_827',\n", + "# 'sal_989_cv_809_lb_819',\n", + "# 'sal_991_cv_810_lb_814'\n", + " ]\n", + "OUT_OF_FOLD_TRAIN_PREDICTIONS = ['{}/{}/out_of_fold_train_predictions.pkl'.format(MODEL_DIRPATH, experiment)\n", + " for experiment in EXPERIMENTS]\n", + "\n", + "OUT_OF_FOLD_TEST_PREDICTIONS = ['{}/{}/out_of_fold_test_predictions.pkl'.format(MODEL_DIRPATH, experiment)\n", + " for experiment in EXPERIMENTS]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_img(path):\n", + " img = 
np.array(Image.open(path))\n", + " return img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metadata = pd.read_csv(METADATA_FILEPATH)\n", + "metadata.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Average out of fold predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "oof_train = joblib.load(OUT_OF_FOLD_TRAIN_PREDICTIONS[0])\n", + "\n", + "mean_train_predictions = {idx:np.zeros((101,101)) for idx in oof_train['ids']}\n", + "\n", + "for filepath in tqdm(OUT_OF_FOLD_TRAIN_PREDICTIONS):\n", + " oof_train = joblib.load(filepath)\n", + " ids, images = oof_train['ids'], oof_train['images']\n", + " for idx, image in zip(ids, images):\n", + " mask = image[1,:,:]\n", + " mean_train_predictions[idx]+=mask\n", + "\n", + "mean_train_predictions = {idx:1.0 * m/len(OUT_OF_FOLD_TRAIN_PREDICTIONS) \n", + " for idx, m in mean_train_predictions.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "THRESHOLD = 0.5 \n", + "iouts = []\n", + "for image_id, prediction_map in tqdm(mean_train_predictions.items()):\n", + " mask = (prediction_map > THRESHOLD).astype(np.uint8)\n", + " ground_truth = load_img(metadata[metadata['id']==image_id]['file_path_mask'].values[0])\n", + " ground_truth = (ground_truth > 0).astype(np.uint8)\n", + " iout = compute_eval_metric(ground_truth, mask)\n", + " iouts.append(iout)\n", + "print('IOUT {}'.format(np.mean(iouts)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Average test predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "oof_test = joblib.load(OUT_OF_FOLD_TEST_PREDICTIONS[0])\n", + "\n", + "mean_test_predictions = {idx:np.zeros((101,101)) for idx in oof_test['ids']}\n", + "\n", + "for filepath in tqdm(OUT_OF_FOLD_TEST_PREDICTIONS):\n", + " oof_test = joblib.load(filepath)\n", + " ids, images = oof_test['ids'], oof_test['images']\n", + " for idx, image in zip(ids, images):\n", + " mask = image[1,:,:]\n", + " mean_test_predictions[idx]+=mask\n", + "\n", + "mean_test_predictions = {idx:1.0 * m/len(OUT_OF_FOLD_TEST_PREDICTIONS) \n", + " for idx, m in mean_test_predictions.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "THRESHOLD = 0.5 \n", + "output = []\n", + "for image_id, prediction_map in tqdm(mean_test_predictions.items()):\n", + " mask = (prediction_map > THRESHOLD).astype(np.uint8)\n", + " rle_encoded = ' '.join(str(rle) for rle in run_length_encoding(mask))\n", + " output.append([image_id, rle_encoded])\n", + "\n", + "submission = pd.DataFrame(output, columns=['id', 'rle_mask']).astype(str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "submission.to_csv(os.path.join(MODEL_DIRPATH, 'prediction_average_cv_838_lb_xxx.csv'),index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dl_py3", + "language": "python", + "name": "dl_py3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/result_exploration.ipynb b/result_exploration.ipynb index 702b18e..bc27805 100644 --- a/result_exploration.ipynb +++ b/result_exploration.ipynb @@ -26,7 +26,7 @@ "METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'\n", "\n", "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = 'YOUR/validation_results.pkl'\n", - "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments/sal_986_cv_821_lb_827/out_of_fold_train_predictions.pkl'" + "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/out_of_fold_predictions/sal_1414_cv_829_lb_839/out_of_fold_train_predictions.pkl'" ] }, { @@ -37,7 +37,23 @@ "source": [ "def load_img(path):\n", " img = np.array(Image.open(path))\n", - " return img" + " return img\n", + "\n", + "def filter_iout(results, iout_range):\n", + " iout_min, iout_max = iout_range\n", + " results_filtered = []\n", + " for tup in results:\n", + " if iout_min<=tup[0]<=iout_max:\n", + " results_filtered.append(tup)\n", + " return results_filtered\n", + "\n", + "def filter_size(results, size_range):\n", + " size_min, size_max = size_range\n", + " results_filtered = []\n", + " for tup in results:\n", + " if size_min<=tup[1]<=size_max:\n", + " results_filtered.append(tup)\n", + " return results_filtered" ] }, { @@ -49,8 +65,8 @@ "metadata = pd.read_csv(METADATA_FILEPATH)\n", "\n", "oof_train = joblib.load(OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH)\n", - "ids = oof_train['ids']#[:100]\n", - "predictions = oof_train['images']#[:100]" + "ids = oof_train['ids']\n", + "predictions = oof_train['images']" ] }, { @@ -79,70 +95,6 @@ " sizes.append(size)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "@ipy.interact(idx = ipy.IntSlider(min=0,max=4000,value=0,step=1))\n", - "def present(idx=idx):\n", - " predicted_map = predicted_maps[idx]\n", - " predicted_mask = predicted_masks[idx]\n", - " image=images[idx]\n", - " mask=masks[idx]\n", - " size = sizes[idx]\n", - " iout = compute_eval_metric(mask, predicted_mask)\n", - " print('IOUT {} size {} depth {}'.format(iout, size, depth))\n", - " plot_list(images=[image,predicted_map],labels=[predicted_mask, mask])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Problems idx 1-200\n", - "\n", - "\n", - " 1. Border masks are a bit to small\n", - "\n", - " **idx** 32, 117\n", - " \n", - " 2. Whole image is salt misspredicted\n", - "\n", - " **idx** 74\n", - " \n", - " 3. Border problems\n", - "\n", - " **idx** 39\n", - " \n", - " 4. One pixel predicted\n", - " \n", - " **idx** \n", - " \n", - " 5. Model Fails\n", - " \n", - " **idx** 60, 105, 114, 121, 176, 191, 196\n", - " \n", - " 6. 
Weak Prediction\n", - " \n", - " **idx** 25, 63, 68, 109, 139, 140, 161, 190\n", - " \n", - " \n", - "## IS THAT TRUE:\n", - " \n", - " **idx** 81" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# This will take a while on the entire dataset it is probably better to take a sample" - ] - }, { "cell_type": "code", "execution_count": null, @@ -157,58 +109,29 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Score distributions" + "# Score by size" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": true + "scrolled": false }, "outputs": [], "source": [ - "sns.distplot(iouts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It seems that the model is either really good or really bad.\n", - "For example:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(np.mean([score for score in iouts if score>0.1]))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# score by depth" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sns.regplot(depths, iouts, fit_reg=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# score by size" + "print('IOUT {:.4f}\\n'.format(np.mean(list(zip(*results))[0])))\n", + "for size_range in [(0,0),(1,300),(300,1000),(1000,3000),(3000,9000), (9000,10201)]:\n", + " results_by_size = filter_size(results, size_range)\n", + " iout = np.mean(list(zip(*results_by_size))[0])\n", + " sample_size = len(results_by_size)\n", + " fraction = len(results_by_size)/len(results)\n", + " print('size {} | IOUT {:.4f} | sample nr {} | fraction {} | max gain {:.4f}'.format(size_range, \n", + " iout,\n", + " sample_size, \n", + " fraction,\n", + " (1.0-iout) * fraction\n", + " ))" ] }, { @@ -217,15 +140,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.regplot(sizes, iouts, fit_reg=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It seems that the model is better for lower depths.\n", - "Not sure yet what to do with it." 
+ "IOUT 0.8288\n", + "\n", + "size (0, 0) | IOUT 0.9539 | sample nr 1562 | fraction 0.3905 | max gain 0.0180\n", + "size (1, 300) | IOUT 0.2588 | sample nr 311 | fraction 0.07775 | max gain 0.0576\n", + "size (300, 1000) | IOUT 0.5446 | sample nr 260 | fraction 0.065 | max gain 0.0296\n", + "size (1000, 3000) | IOUT 0.7626 | sample nr 508 | fraction 0.127 | max gain 0.0301\n", + "size (3000, 9000) | IOUT 0.8994 | sample nr 1090 | fraction 0.2725 | max gain 0.0274\n", + "size (9000, 10201) | IOUT 0.8732 | sample nr 272 | fraction 0.068 | max gain 0.0086" ] }, { @@ -235,29 +157,6 @@ "# Predicted mask exploration" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def filter_iout(results, iout_range):\n", - " iout_min, iout_max = iout_range\n", - " results_filtered = []\n", - " for tup in results:\n", - " if iout_min<=tup[0]<=iout_max:\n", - " results_filtered.append(tup)\n", - " return results_filtered\n", - "\n", - "def filter_size(results, size_range):\n", - " size_min, size_max = size_range\n", - " results_filtered = []\n", - " for tup in results:\n", - " if size_min<=tup[1]<=size_max:\n", - " results_filtered.append(tup)\n", - " return results_filtered" - ] - }, { "cell_type": "code", "execution_count": null, @@ -266,12 +165,17 @@ }, "outputs": [], "source": [ - "IMG_NR = 10\n", - "# results_filtered = filter_iout(results, iout_range=(0.0,0.2))\n", - "results_filtered = filter_size(results, size_range=(1000, 100000))\n", - "print(len(results_filtered))\n", + "results_filtered = results.copy()\n", + "results_filtered = filter_iout(results_filtered, iout_range=(0.0,0.2))\n", + "results_filtered = filter_size(results_filtered, size_range=(1,300))\n", + "\n", + "print('sample nr {} fraction {} mean IOUT {}'.format(len(results_filtered), \n", + " len(results_filtered)/len(results),\n", + " np.mean(list(zip(*results_filtered))[0])))\n", "\n", - "for iout, s, z, img, pred_mask, pred_map, gt in results_filtered[:IMG_NR]:\n", + "@ipy.interact(idx = ipy.IntSlider(min=0,max=len(results_filtered)-1,value=0,step=1))\n", + "def present(idx=idx):\n", + " iout, s, z, img, pred_mask, pred_map, gt = results_filtered[idx]\n", " print('IOUT {}, size {}, depth {}'.format(iout, s, z))\n", " plot_list(images=[img, pred_map],labels=[pred_mask, gt])" ] diff --git a/small_mask_generation.ipynb b/small_mask_generation.ipynb new file mode 100644 index 0000000..0cd3ec6 --- /dev/null +++ b/small_mask_generation.ipynb @@ -0,0 +1,187 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "\n", + "import os\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.externals import joblib\n", + "from tqdm import tqdm_notebook as tqdm\n", + "import ipywidgets as ipy\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from common_blocks.utils import plot_list, load_image, save_image, resize_image, get_cut_coordinates\n", + "\n", + "METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'\n", + "IMG_DIR = '/mnt/ml-team/minerva/open-solutions/salt/files/auxiliary_data'\n", + "IMG_DIR_MASKS =os.path.join(IMG_DIR,'masks')\n", + "AUXILIARY_METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/auxiliary_metadata.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metadata = 
pd.read_csv(METADATA_FILEPATH)\n", + "metadata_train = metadata[metadata['is_train']==1]\n", + "metadata_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Filter larger masks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sizes = []\n", + "for file_path in tqdm(metadata_train.file_path_mask):\n", + " mask = load_image(file_path, is_mask=True)\n", + " sizes.append(mask.sum())\n", + " \n", + "metadata_train['size'] = sizes\n", + "\n", + "metadata_large_masks = metadata_train[metadata_train['size'].between(300,8000)]\n", + "metadata_large_masks.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Explore cut results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@ipy.interact(idx=ipy.IntSlider(min=0,max=4000,value=0,step=1))\n", + "def present_cut(idx):\n", + " row = metadata_large_masks.iloc[idx]\n", + " image = load_image(row.file_path_image, is_mask=False)\n", + " mask = load_image(row.file_path_mask, is_mask=True)\n", + " (x1,x2),(y1,y2), was_cropped = get_cut_coordinates(mask,step=4, min_size=50, max_size=300)\n", + " if was_cropped:\n", + " synthetic_mask = resize_image(mask[x1:x2,y1:y2], (101,101),is_mask=True)\n", + " synthetic_image = resize_image(image[x1:x2,y1:y2], (101,101))\n", + " plot_list(images=[image, synthetic_image], labels=[mask, synthetic_mask])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare synthetic data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "file_path_images,file_path_masks,ids, zs = [],[],[],[]\n", + "for _, row in tqdm(metadata_large_masks.iterrows()):\n", + " image = load_image(row.file_path_image, is_mask=False)\n", + " mask = load_image(row.file_path_mask, is_mask=True)\n", + " (x1,x2),(y1,y2), was_cropped = get_cut_coordinates(mask,step=4, min_size=50, max_size=300)\n", + " if was_cropped:\n", + " synthetic_mask = resize_image(mask[x1:x2,y1:y2], (101,101),is_mask=True).astype(np.uint8)\n", + " synthetic_image = (resize_image(image[x1:x2,y1:y2], (101,101))*255.).astype(np.uint8)\n", + " idx = row.id\n", + " \n", + " file_path_image=os.path.join(IMG_DIR,'images','{}.png'.format(idx))\n", + " file_path_mask=os.path.join(IMG_DIR,'masks','{}.png'.format(idx))\n", + " save_image(synthetic_image, file_path_image)\n", + " save_image(synthetic_mask, file_path_mask)\n", + " test=load_image(file_path_mask)\n", + " \n", + " file_path_images.append(file_path_image)\n", + " file_path_masks.append(file_path_mask)\n", + " ids.append(idx)\n", + " zs.append(row.z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metadata_small_masks = pd.DataFrame({'file_path_image':file_path_images,\n", + " 'file_path_mask':file_path_masks,\n", + " 'id':ids,\n", + " 'z':zs\n", + " })\n", + "metadata_small_masks['is_train']=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "display(metadata_small_masks.shape)\n", + "metadata_small_masks.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metadata_small_masks.to_csv(AUXILIARY_METADATA_FILEPATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "dl_py3", + "language": "python", + "name": "dl_py3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/weighted_focal_loss.ipynb b/weighted_focal_loss.ipynb new file mode 100644 index 0000000..2ac99e8 --- /dev/null +++ b/weighted_focal_loss.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES']=''\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import glob\n", + "import cv2\n", + "from PIL import Image\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.externals import joblib\n", + "from skimage.transform import resize\n", + "from tqdm import tqdm_notebook as tqdm\n", + "from torch.autograd import Variable\n", + "import torch\n", + "import ipywidgets as ipy\n", + "\n", + "from common_blocks.augmentation import resize_pad_seq\n", + "from common_blocks.utils import plot_list, read_images\n", + "from common_blocks.models import weighted_focal_loss\n", + "from common_blocks.metrics import compute_eval_metric\n", + "\n", + "METADATA_FILEPATH = 'YOUR/metadata.csv'\n", + "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = 'YOUR/validation_results.pkl'\n", + "\n", + "METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'\n", + "OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments/sal_1036_cv_829_lb_837/out_of_fold_train_predictions.pkl'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_img(path):\n", + " img = np.array(Image.open(path))\n", + " return img\n", + "\n", + "def filter_size(sizes, size_range):\n", + " size_min, size_max = size_range\n", + " filtered_idx = []\n", + " for idx, tup in enumerate(sizes):\n", + " if size_min<=tup<=size_max:\n", + " filtered_idx.append(idx)\n", + " return filtered_idx\n", + "\n", + "image_prep = resize_pad_seq(102, 'edge', 13)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metadata = pd.read_csv(METADATA_FILEPATH)\n", + "\n", + "oof_train = joblib.load(OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH)\n", + "ids = oof_train['ids']\n", + "predictions = oof_train['images']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "THRESHOLD = 0.5\n", + "\n", + "predicted_maps, masks, images, iouts, sizes = [],[],[],[],[]\n", + "for idx, pred in tqdm(zip(ids, predictions)):\n", + " row = metadata[metadata['id']==idx]\n", + " predicted_map = np.zeros((2,101,101))\n", + " predicted_map[0,:,:] = resize(pred[0,:,:],(101,101),mode='constant')\n", + " predicted_map[1,:,:] = resize(pred[1,:,:],(101,101),mode='constant')\n", + " predicted_mask = (predicted_map[1,:,:] > THRESHOLD).astype(int)\n", + " mask = (load_img(row.file_path_mask.values[0]) > 0).astype(int)\n", + " image = load_img(row.file_path_image.values[0])\n", + " iout = compute_eval_metric(mask, predicted_mask)\n", + " size = np.sum(mask)\n", + " images.append(image)\n", + " masks.append(mask)\n", + " 
predicted_maps.append(predicted_map)\n", + " iouts.append(iout)\n", + " sizes.append(size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "size_idxs = filter_size(sizes, size_range=(1, 300))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@ipy.interact(idx = ipy.IntSlider(min=0,max=len(size_idxs)-1,value=0,step=1),\n", + " alpha = ipy.FloatSlider(min=0,max=1,value=1.0,step=0.05),\n", + " gamma = ipy.FloatSlider(min=0,max=10,value=0.0,step=0.1),\n", + " max_weight = ipy.FloatSlider(min=1,max=1000.0,value=100.0,step=1.0),\n", + " focus_threshold = ipy.FloatSlider(min=0,max=1,value=0.0,step=0.1),\n", + " use_size_weight = ipy.Checkbox(value=True),\n", + " use_border_weight = ipy.Checkbox(value=True),\n", + " border_size = ipy.IntSlider(min=0,max=30,value=10,step=1),\n", + " border_weight = ipy.FloatSlider(min=0,max=10.,value=10.0,step=0.25))\n", + "def present(idx, alpha, gamma,focus_threshold,\n", + " max_weight,use_size_weight, use_border_weight,border_size, border_weight):\n", + " data_idx = size_idxs[idx]\n", + " predicted_map = predicted_maps[data_idx]\n", + " logit = np.log(predicted_map/(1.0-predicted_map))\n", + " output = np.expand_dims(logit,axis=0)\n", + " \n", + " mask = masks[data_idx]\n", + "\n", + " target = np.zeros_like(output)\n", + " target[:,1,:,:] = mask\n", + " target[:,0,:,:] = (mask == 0).astype(np.uint8)\n", + "\n", + " iout = iouts[data_idx]\n", + " output = Variable(torch.Tensor(output))\n", + " target = Variable(torch.Tensor(target))\n", + " image = images[data_idx]\n", + "\n", + " focal_loss = weighted_focal_loss(output, target,\n", + " alpha=alpha, gamma=gamma,\n", + " max_weight=max_weight,\n", + " use_size_weight=use_size_weight,\n", + " use_border_weight=use_border_weight,\n", + " focus_threshold=focus_threshold,\n", + " border_size=border_size, border_weight=border_weight)\n", + " focal_loss = focal_loss.data.cpu().numpy()[0]\n", + " \n", + " bce_loss = torch.nn.BCEWithLogitsLoss()(output, target)\n", + " bce_loss = bce_loss.data.cpu().numpy()[0]\n", + " \n", + " print('BCE {:.4f}, Focal Loss {:.4f}, IOUT {:.2f}'.format(bce_loss, focal_loss, iout))\n", + " plot_list(images=[image, predicted_map[1,:,:]],labels=[mask])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dl_py3", + "language": "python", + "name": "dl_py3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
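
For reference, the new ReduceLROnPlateauScheduler callback replaces the exponential schedule with torch's ReduceLROnPlateau, stepped on the validation metric once per epoch. A minimal standalone sketch of that wiring, assuming the reduce_lr_on_plateau_scheduler values from neptune.yaml (the dummy model and the constant metric are placeholders, not code from the repository):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(10, 1)                              # stand-in for the U-Net
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # lr as in PARAMS.lr
scheduler = ReduceLROnPlateau(optimizer,
                              mode='max',                    # IOUT is maximized (minimize=0)
                              factor=0.1,                    # reduce_factor
                              patience=10,                   # reduce_patience, in epochs
                              min_lr=1e-7)                   # min_lr

for epoch in range(100):
    validation_iout = 0.8                                    # placeholder for the metric the callback computes
    scheduler.step(validation_iout)                          # LR is multiplied by `factor` after `patience` epochs without improvement
    current_lr = optimizer.param_groups[0]['lr']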
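
The validation monitor now picks the binarization threshold per evaluation instead of using the fixed THRESHOLD: it walks thresholds down from 0.5 and stops at the first one that does not improve IOUT. A rough standalone equivalent of that selection loop (score_fn stands in for intersection_over_union_thresholds; all names here are illustrative):

import numpy as np

def select_threshold(y_true, probability_maps, score_fn, thresholds=np.linspace(0.5, 0.3, 21)):
    # Greedy sweep: keep lowering the threshold while the score keeps improving.
    best_score, best_threshold = 0.0, thresholds[0]
    for threshold in thresholds:
        y_pred = [(p > threshold).astype(np.uint8) for p in probability_maps]
        score = score_fn(y_true, y_pred)
        if score > best_score:
            best_score, best_threshold = score, threshold
        else:
            break
    return best_threshold, best_score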
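
AddDepthChannels, appended to both image transforms in loaders.py, reuses two of the three (identical) input channels to encode position: channel 1 becomes a 0-to-1 vertical gradient and channel 2 the product of intensity and that gradient. A self-contained sketch with a usage example, assuming the 3-channel CHW float tensor produced by ToTensor/Normalize:

import numpy as np
import torch

class AddDepthChannels:
    def __call__(self, tensor):
        _, h, _ = tensor.size()
        for row, const in enumerate(np.linspace(0, 1, h)):
            tensor[1, row, :] = const        # per-row "depth" coordinate in [0, 1]
        tensor[2] = tensor[0] * tensor[1]    # intensity modulated by depth
        return tensor

image = torch.rand(3, 101, 101)              # grayscale image repeated over 3 channels
image = AddDepthChannels()(image)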
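
The rewritten DecoderBlock ends each stage with channel and spatial squeeze-and-excitation (the ChannelSELayer/SpatialSELayer pair) and fuses the two gated maps by addition before the ReLU. A compact standalone sketch of that scSE gating, assuming PyTorch; the combined module name below is illustrative, not from the diff:

import torch
import torch.nn as nn

class SCSEGate(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        # channel branch: global average pool -> bottleneck MLP -> per-channel scale
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Linear(channels, channels // reduction),
                                nn.ReLU(inplace=True),
                                nn.Linear(channels // reduction, channels),
                                nn.Sigmoid())
        # spatial branch: 1x1 convolution -> per-pixel scale
        self.spatial = nn.Sequential(nn.Conv2d(channels, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        channel_scale = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        spatial_scale = self.spatial(x)
        return torch.relu(x * channel_scale + x * spatial_scale)

gate = SCSEGate(64)
y = gate(torch.randn(2, 64, 32, 32))          # output keeps the input shape: (2, 64, 32, 32)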