Skip to content

Commit

Permalink
Add random fill value for GridDropout and MaskDropout
Browse files Browse the repository at this point in the history
  • Loading branch information
akarsakov committed Apr 21, 2020
1 parent c8618c9 commit 3fe268f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 24 deletions.
47 changes: 27 additions & 20 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import random
import typing
import warnings
from enum import IntEnum
from types import LambdaType
Expand All @@ -13,7 +14,7 @@
from . import functional as F
from .bbox_utils import denormalize_bbox, normalize_bbox, union_of_bboxes
from ..core.transforms_interface import DualTransform, ImageOnlyTransform, NoOp, to_tuple
from ..core.utils import format_args
from ..core.utils import format_args, get_random_color

__all__ = [
"Blur",
Expand Down Expand Up @@ -1574,13 +1575,8 @@ def get_params_dependent_on_targets(self, params):
holes.append((x1, y1, x2, y2))

fill_value = self.fill_value
if fill_value == "random":
ch = F.get_num_channels(img)

if img.dtype == np.uint8:
fill_value = np.random.randint(0, 256, ch, np.uint8)
else:
fill_value = np.random.uniform(0, 1, size=ch).astype(np.float32)
if self.fill_value == "random":
fill_value = get_random_color(F.get_num_channels(img), img.dtype)

return {"holes": holes, "fill_value": fill_value}

Expand Down Expand Up @@ -3119,8 +3115,9 @@ def __init__(self, max_objects=1, image_fill_value=0, mask_fill_value=0, always_
Args:
max_objects: Maximum number of labels that can be zeroed out. Can be tuple, in this case it's [min, max]
image_fill_value: Fill value to use when filling image.
Can be 'inpaint' to apply inpainting (works only for 3-channel images)
mask_fill_value: Fill value to use when filling mask.
Can be 'inpaint' to apply inpainting (works only for 3-channel images).
If image_fill_value is 'random', random color will be generated. Default = 0.
mask_fill_value: Fill value to use when filling mask. Default = 0.
Targets:
image, mask
Expand All @@ -3135,7 +3132,7 @@ def __init__(self, max_objects=1, image_fill_value=0, mask_fill_value=0, always_

@property
def targets_as_params(self):
return ["mask"]
return ["image", "mask"]

def get_params_dependent_on_targets(self, params):
mask = params["mask"]
Expand All @@ -3156,21 +3153,26 @@ def get_params_dependent_on_targets(self, params):
for label_index in labels_index:
dropout_mask |= label_image == label_index

params.update({"dropout_mask": dropout_mask})
image_fill_value = self.image_fill_value
if self.image_fill_value == "random":
img = params["image"]
image_fill_value = get_random_color(F.get_num_channels(img), img.dtype)

params.update({"dropout_mask": dropout_mask, "image_fill_value": image_fill_value})
return params

def apply(self, img, dropout_mask=None, **params):
def apply(self, img, dropout_mask=None, image_fill_value=0, **params):
if dropout_mask is None:
return img

if self.image_fill_value == "inpaint":
if image_fill_value == "inpaint":
dropout_mask = dropout_mask.astype(np.uint8)
_, _, w, h = cv2.boundingRect(dropout_mask)
radius = min(3, max(w, h) // 2)
img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)
else:
img = img.copy()
img[dropout_mask] = self.image_fill_value
img[dropout_mask] = image_fill_value

return img

Expand Down Expand Up @@ -3262,7 +3264,8 @@ class GridDropout(DualTransform):
Clipped between 0 and grid unit height - hole_height. Default: 0.
random_offset (boolean): whether to offset the grid randomly between 0 and grid unit size - hole size
If 'True', entered shift_x, shift_y are ignored and set randomly. Default: `False`.
fill_value (int): value for the dropped pixels. Default = 0
fill_value (int, string): value for the dropped pixels.
If fill_value is 'random', random color will be generated. Default = 0
mask_fill_value (int): value for the dropped pixels in mask.
If `None`, transformation is not applied to the mask. Default: `None`.
Targets:
Expand All @@ -3283,7 +3286,7 @@ def __init__(
shift_x: int = 0,
shift_y: int = 0,
random_offset: bool = False,
fill_value: int = 0,
fill_value: typing.Union[int, str] = 0,
mask_fill_value: int = None,
always_apply: bool = False,
p: float = 0.5,
Expand All @@ -3302,8 +3305,8 @@ def __init__(
if not 0 < self.ratio <= 1:
raise ValueError("ratio must be between 0 and 1.")

def apply(self, image, holes=(), **params):
return F.cutout(image, holes, self.fill_value)
def apply(self, image, holes=(), fill_value=0, **params):
return F.cutout(image, holes, fill_value)

def apply_to_mask(self, image, holes=(), **params):
if self.mask_fill_value is None:
Expand Down Expand Up @@ -3363,7 +3366,11 @@ def get_params_dependent_on_targets(self, params):
y2 = min(y1 + hole_height, height)
holes.append((x1, y1, x2, y2))

return {"holes": holes}
fill_value = self.fill_value
if self.fill_value == "random":
fill_value = get_random_color(F.get_num_channels(img), img.dtype)

return {"holes": holes, "fill_value": fill_value}

@property
def targets_as_params(self):
Expand Down
10 changes: 10 additions & 0 deletions albumentations/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import
from abc import ABCMeta, abstractmethod
import numpy as np

from ..core.six import string_types, add_metaclass

Expand All @@ -13,6 +14,15 @@ def format_args(args_dict):
return ", ".join(formatted_args)


def get_random_color(img_channels, dtype=np.uint8):
    """Draw one random value per channel, suitable as a fill color.

    Args:
        img_channels: Number of values to generate (passed to NumPy as the
            output size).
        dtype: Data type of the image the color is meant for. When it
            compares equal to ``np.uint8`` the values are integers drawn
            from [0, 256); for anything else they are floats drawn
            uniformly from [0, 1) and cast to ``np.float32``.

    Returns:
        A NumPy array of ``img_channels`` random values, of dtype
        ``np.uint8`` or ``np.float32`` as described above.
    """
    if dtype == np.uint8:
        return np.random.randint(low=0, high=256, size=img_channels, dtype=np.uint8)

    # Every non-uint8 dtype takes the float path and yields float32.
    sampled = np.random.uniform(0, 1, size=img_channels)
    return sampled.astype(np.float32)


@add_metaclass(ABCMeta)
class Params:
def __init__(self, format, label_fields=None):
Expand Down
17 changes: 13 additions & 4 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,11 +654,20 @@ def test_gauss_noise_incorrect_var_limit_type():
assert str(exc_info.value) == message


def test_fill_value_random():
@pytest.mark.parametrize(
["augmentation_cls", "params"],
[
[A.CoarseDropout, {"fill_value": "random"}],
[A.GridDropout, {"fill_value": "random"}],
[A.MaskDropout, {"image_fill_value": "random"}],
],
)
def test_fill_value_random(augmentation_cls, params):
image = np.zeros((100, 100, 3))
aug = A.CoarseDropout(5, 10, 10, fill_value="random", always_apply=True)
mask = np.random.randint(0, 5, image.shape[:2], dtype=np.uint8)
aug = augmentation_cls(always_apply=True, **params)

augmented1 = aug(image=image)["image"]
augmented2 = aug(image=image)["image"]
augmented1 = aug(image=image, mask=mask)["image"]
augmented2 = aug(image=image, mask=mask)["image"]

assert not np.allclose(np.unique(augmented1), np.unique(augmented2))

0 comments on commit 3fe268f

Please sign in to comment.