Skip to content

Commit

Permalink
Feature: add bbox reflection
Browse files Browse the repository at this point in the history
Some vectorized bbox functions are also added
  • Loading branch information
i-aki-y committed Sep 1, 2022
1 parent 228c9d8 commit 8703fe9
Show file tree
Hide file tree
Showing 4 changed files with 854 additions and 4 deletions.
39 changes: 37 additions & 2 deletions albumentations/augmentations/crops/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
_maybe_process_in_chunks,
preserve_channel_dim,
)

from ...core.bbox_utils import denormalize_bbox, normalize_bbox
from ...core.bbox_utils import (
denormalize_bbox, normalize_bbox,
denormalize_bboxes2, normalize_bboxes2,
)
from ...core.transforms_interface import BoxInternalType, KeypointInternalType
from ..geometric import functional as FGeometric

__all__ = [
"get_random_crop_coords",
"random_crop",
"crop_bbox_by_coords",
"crop_bboxes_by_coords",
"bbox_random_crop",
"crop_keypoint_by_coords",
"keypoint_random_crop",
Expand Down Expand Up @@ -87,6 +90,38 @@ def crop_bbox_by_coords(
return normalize_bbox(cropped_bbox, crop_height, crop_width)


def crop_bboxes_by_coords(
    bboxes: np.ndarray,
    crop_coords: Tuple[int, int, int, int],
    crop_height: int,
    crop_width: int,
    rows: int,
    cols: int,
):
    """Re-express a batch of bounding boxes relative to a crop window (array version of `crop_bbox_by_coords`).

    The boxes are passed through `denormalize_bboxes2` with the source image size, shifted so that the
    crop's `(x1, y1)` corner becomes the origin, and then passed through `normalize_bboxes2` with the
    crop size. Boxes are neither clipped nor filtered here.

    Args:
        bboxes (np.ndarray): Two dimensional array, one box per row, starting with
            `x_min, y_min, x_max, y_max`. Any further columns are carried over by the copy
            (presumably untouched by the (de)normalize helpers - verify against `bbox_utils`).
        crop_coords (tuple): Crop coordinates `(x1, y1, x2, y2)`; only `x1` and `y1` are used.
        crop_height (int): Height handed to `normalize_bboxes2` for the result.
        crop_width (int): Width handed to `normalize_bboxes2` for the result.
        rows (int): Image rows handed to `denormalize_bboxes2`.
        cols (int): Image cols handed to `denormalize_bboxes2`.

    Returns:
        np.ndarray: The shifted boxes as returned by `normalize_bboxes2`.

    Raises:
        ValueError: If `bboxes` is not an `np.ndarray`.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    pixel_boxes = np.array(denormalize_bboxes2(bboxes, rows, cols))
    x_offset, y_offset, _, _ = crop_coords

    # Translate both corners of every box by the crop origin.
    shifted = pixel_boxes.copy()
    shifted[:, [0, 2]] = pixel_boxes[:, [0, 2]] - x_offset
    shifted[:, [1, 3]] = pixel_boxes[:, [1, 3]] - y_offset

    return normalize_bboxes2(shifted, crop_height, crop_width)


def bbox_random_crop(
bbox: BoxInternalType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
):
Expand Down
301 changes: 300 additions & 1 deletion albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import skimage.transform
from scipy.ndimage.filters import gaussian_filter


from albumentations.augmentations.utils import (
_maybe_process_in_chunks,
angle_2pi_range,
clipped,
preserve_channel_dim,
preserve_shape,
)

from ..crops import functional as FCrop
from ... import random_utils
from ...core.bbox_utils import denormalize_bbox, normalize_bbox
from ...core.transforms_interface import (
Expand All @@ -37,6 +38,8 @@
"shift_scale_rotate",
"keypoint_shift_scale_rotate",
"bbox_shift_scale_rotate",
"bboxes_shift_scale_rotate",
"bboxes_shift_scale_rotate_reflect",
"elastic_transform",
"resize",
"scale",
Expand Down Expand Up @@ -299,6 +302,230 @@ def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rotate_method, rows, col
return x_min, y_min, x_max, y_max


def bboxes_shift_scale_rotate(bboxes, angle, scale, dx, dy, rows, cols, **kwargs):
    """Apply a shift/scale/rotate affine transform to many boxes at once (numpy version of bbox_shift_scale_rotate).

    Every box is expanded to its four corners, the corners are mapped through the affine matrix,
    and the axis-aligned envelope of the mapped corners becomes the new box.

    Args:
        bboxes (np.ndarray): Two dimensional array whose first four columns are
            `x_min, y_min, x_max, y_max`; they are multiplied by the image size before the
            transform and divided by it afterwards. Extra columns are copied through unchanged.
        angle: Rotation angle passed to `cv2.getRotationMatrix2D`.
        scale: Scale factor passed to `cv2.getRotationMatrix2D`.
        dx: Horizontal shift as a fraction of `cols`.
        dy: Vertical shift as a fraction of `rows`.
        rows (int): Image rows.
        cols (int): Image cols.
        **kwargs: Accepted and ignored.

    Returns:
        np.ndarray: A copy of `bboxes` with the first four columns replaced by the transformed envelope.

    Raises:
        ValueError: If `bboxes` is not an `np.ndarray`.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    num_boxes = bboxes.shape[0]
    width, height = cols, rows

    # Rotation/scale about the image centre, with the translation added in pixels.
    affine = cv2.getRotationMatrix2D((width / 2, height / 2), angle, scale)
    affine[0, 2] += dx * width
    affine[1, 2] += dy * height

    # Stack the four corners of every box as homogeneous points [x, y, 1]:
    # rows [k * n, (k + 1) * n) hold the k-th corner of all boxes.
    corner_columns = ([0, 1], [2, 1], [2, 3], [0, 3])
    homogeneous = np.empty((4 * num_boxes, 3), dtype=bboxes.dtype)
    for k, column_pair in enumerate(corner_columns):
        homogeneous[k * num_boxes : (k + 1) * num_boxes, :2] = bboxes[:, column_pair]
    homogeneous[:, 2] = 1

    # Normalized -> pixel coordinates.
    homogeneous[:, 0] *= width
    homogeneous[:, 1] *= height

    mapped = np.dot(affine, homogeneous.T).T

    # Pixel -> normalized coordinates.
    mapped[:, 0] /= width
    mapped[:, 1] /= height

    # Regroup as [corner, box, xy] and take the envelope over the corner axis.
    corners = mapped.reshape((4, num_boxes, 2))
    corner_xs = corners[:, :, 0]
    corner_ys = corners[:, :, 1]

    result = bboxes.copy()
    result[:, 0] = corner_xs.min(axis=0)
    result[:, 1] = corner_ys.min(axis=0)
    result[:, 2] = corner_xs.max(axis=0)
    result[:, 3] = corner_ys.max(axis=0)
    return result


def _estimate_expand_grid_size(rows, cols, scale=1):
"""Estimate the number of grid cells to cover the reachable area (This may be overestimate)."""
cell_size = max(rows, cols)
d_x = scale * cols
d_y = scale * rows

# Estimate the distance of the point that is farthest away from the center point
# by thinking about the case of 1.0 translation and 45 degree rotation.
n_x = 1 + 2 * int(np.ceil((cell_size * (1 + np.sqrt(2) / 2) - d_x / 2) / d_x))
n_y = 1 + 2 * int(np.ceil((cell_size * (1 + np.sqrt(2) / 2) - d_y / 2) / d_y))
return n_x, n_y


def bboxes_expand_grid(bboxes, n_x, n_y, rows, cols, border_mode=cv2.BORDER_WRAP):
    """Make an n_x by n_y grid from copies of the bounding boxes, with the original cell at the centre.

    The border type is taken into account when the copies are laid out: wrap mode repeats the boxes
    unchanged, the reflect modes alternate flipped copies. The returned coordinates are divided by
    the grid size, i.e. they are normalized with respect to the whole grid rather than one cell.

    Args:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.
        n_x (int): Number of grid cells along the x-axis. Must be a positive odd number.
        n_y (int): Number of grid cells along the y-axis. Must be a positive odd number.
        rows (int): Image rows.
        cols (int): Image cols.
        border_mode (int): Border mode. Should be one of:
            `cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101`.

    Returns:
        bboxes (np.ndarray): A two dimensional ndarray with `n_x * n_y` times as many rows as the
            input. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.

    Raises:
        ValueError: If `bboxes` is not an ndarray, if `n_x` or `n_y` is not a positive odd number,
            if `border_mode` is unsupported, or if the rows do not have 4 or 5 elements.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    if n_x <= 0 or n_y <= 0:
        raise ValueError(f"n_x and n_y should be non zero positive numbers. got {n_x, n_y}")

    # Each count must be odd on its own, otherwise that axis has no single centre cell.
    if n_x % 2 == 0 or n_y % 2 == 0:
        raise ValueError(f"n_x and n_y should be odd numbers. got {n_x, n_y}")

    if border_mode not in [cv2.BORDER_REFLECT, cv2.BORDER_REFLECT_101, cv2.BORDER_WRAP]:
        raise ValueError(
            f"Select border_mode from "
            f"[cv2.BORDER_REFLECT({cv2.BORDER_REFLECT}), cv2.BORDER_WRAP({cv2.BORDER_WRAP}),"
            f" cv2.BORDER_REFLECT_101({cv2.BORDER_REFLECT_101})], got {border_mode}"
        )

    nb, n_coord = bboxes.shape[:2]

    if n_coord not in [4, 5]:
        raise ValueError(f"The number of bounding box elements should be 4 or 5. got {n_coord}")

    # Coordinates of the bboxes in the (i, j)-th grid cell from the top-left origin are given by bboxes + (j, i, j, i).
    # Pre-compute the "shift" matrix for later use. (the fifth element is a dummy value that stays zero)
    # The shift has the following values:
    #
    # (0, 0, 0, 0) (1, 0, 1, 0) (2, 0, 2, 0) ...
    # (0, 1, 0, 1) (1, 1, 1, 1) (2, 1, 2, 1) ...
    # (0, 2, 0, 2) (1, 2, 1, 2) (2, 2, 2, 2) ...
    # ...
    #
    shift = np.indices([n_x, n_y]).transpose(2, 1, 0)
    if n_coord == 4:
        shift = np.concatenate([shift, shift], axis=2)
    else:  # n_coord == 5
        shift = np.concatenate([shift, shift, np.zeros((n_y, n_x, 1))], axis=2)

    # [grid rows, grid cols, number of boxes (broadcast), bbox coordinates]
    shift = shift.reshape((n_y, n_x, 1, n_coord))

    if border_mode in [cv2.BORDER_REFLECT, cv2.BORDER_REFLECT_101]:
        # With n_x=3, n_y=3 and the input "bbox" displayed by "b", the resulting output will be:
        #
        #      q p q
        # b -> d b d
        #      q p q
        #
        # where, b: original, d: h_flip, p: v_flip, q: h_flip and v_flip

        # Pre-calculate all flipped bboxes patterns, indexed by (v_flip flag, h_flip flag).
        # The dtype comes from the array itself so that an empty `bboxes` is handled too.
        flip_bbox_map = np.zeros((2, 2, nb, n_coord), dtype=bboxes.dtype)
        flip_bbox_map[0, 0] = bboxes
        flip_bbox_map[0, 1] = bboxes_hflip(bboxes, rows, cols)
        flip_bbox_map[1, 0] = bboxes_vflip(bboxes, rows, cols)
        flip_bbox_map[1, 1] = bboxes_flip(bboxes, -1, rows, cols)

        # The relation between flips and flags is:
        #
        # q p q     (1, 1)(1, 0)(1, 1)
        # d b d <-> (0, 1)(0, 0)(0, 1)
        # q p q     (1, 1)(1, 0)(1, 1)
        #
        # The offsets make the centre cell (n_y // 2, n_x // 2) come out as (0, 0), i.e. unflipped.
        flag_y0 = (n_y // 2) % 2
        flag_x0 = (n_x // 2) % 2
        flags_y, flags_x = np.indices((n_y, n_x)) % 2
        flags_y = (flag_y0 + flags_y) % 2
        flags_x = (flag_x0 + flags_x) % 2
        # Expand copies of bboxes over the grid
        bboxes_exp = flip_bbox_map[flags_y, flags_x] + shift

    else:  # border_mode == cv2.BORDER_WRAP (validated above)
        # With n=3 and the input "bbox" displayed by "b", the resulting output will be:
        #
        #      b b b
        # b -> b b b
        #      b b b
        #
        bboxes_exp = np.repeat(np.repeat(bboxes.reshape(1, 1, nb, n_coord), n_y, axis=0), n_x, axis=1) + shift

    # Normalize by the grid size; the label column (if any) is divided by 1 and so keeps its value.
    grid_sizes = [n_x, n_y, n_x, n_y]
    if n_coord == 5:
        grid_sizes += [1]
    bboxes_exp = bboxes_exp.reshape(-1, n_coord) / np.array(grid_sizes).reshape(1, n_coord)

    return bboxes_exp


def bboxes_shift_scale_rotate_reflect(bboxes, angle, scale, dx, dy, rows, cols, border_mode=cv2.BORDER_REFLECT_101):
    """Shift/scale/rotate bounding boxes, keeping the copies that a reflecting or wrapping border brings into view.

    For the reflect and wrap border modes the boxes are first tiled over an n_x by n_y grid of cells,
    the affine transform is applied to the whole grid, and the centre cell is cropped back out.
    For any other border mode only the original boxes are transformed.

    Args:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.
        angle: Rotation angle forwarded to `bboxes_shift_scale_rotate`.
        scale: Scale factor forwarded to `bboxes_shift_scale_rotate`.
        dx: Horizontal shift as a fraction of `cols`.
        dy: Vertical shift as a fraction of `rows`.
        rows (int): Image rows.
        cols (int): Image cols.
        border_mode (int): Border mode. Tiling is done for
            `cv2.BORDER_REFLECT_101, cv2.BORDER_WRAP, cv2.BORDER_REFLECT`.

    Returns:
        np.ndarray: The transformed boxes, without the rows that lie entirely outside the cell.
            Boxes that only partially overlap the cell are kept as they are (no clipping is done here).

    Raises:
        ValueError: If `bboxes` is not an `np.ndarray`.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    # Lay the boxes out on an n_x by n_y grid.
    # ex. n_x = n_y = 3, border = cv2.BORDER_REFLECT
    #      +-+-+-+
    #      |q|p|q|
    #      +-+-+-+
    # b -> |d|b|d|
    #      +-+-+-+
    #      |q|p|q|
    #      +-+-+-+
    tiling_modes = (cv2.BORDER_REFLECT_101, cv2.BORDER_WRAP, cv2.BORDER_REFLECT)
    if border_mode not in tiling_modes:
        # No tiling: the "grid" is the single original cell.
        n_x = n_y = 1
        tiled = bboxes
    else:
        # The estimate always yields odd counts, so a unique centre cell exists.
        n_x, n_y = _estimate_expand_grid_size(rows, cols, scale)
        tiled = bboxes_expand_grid(bboxes, n_x, n_y, rows, cols, border_mode=border_mode)

    # Grid index of the centre cell.
    center_col, center_row = n_x // 2, n_y // 2
    grid_rows, grid_cols = n_y * rows, n_x * cols

    # The transform runs on the enlarged grid, so the relative shifts are rescaled
    # to keep the shift equal to dx * cols and dy * rows in pixels.
    moved = bboxes_shift_scale_rotate(tiled, angle, scale, dx / n_x, dy / n_y, grid_rows, grid_cols)

    # Cut the centre cell back out of the grid.
    crop_coords = [center_col * cols, center_row * rows, (center_col + 1) * cols, (center_row + 1) * rows]
    cropped = FCrop.crop_bboxes_by_coords(moved, crop_coords, rows, cols, grid_rows, grid_cols)

    # Discard boxes that ended up completely outside the unit cell on either axis.
    xs = cropped[:, [0, 2]]
    ys = cropped[:, [1, 3]]
    outside_x = (xs.max(axis=1) <= 0) | (xs.min(axis=1) >= 1)
    outside_y = (ys.max(axis=1) <= 0) | (ys.min(axis=1) >= 1)
    return cropped[~(outside_x | outside_y), :]


@preserve_shape
def elastic_transform(
img: np.ndarray,
Expand Down Expand Up @@ -1280,3 +1507,75 @@ def elastic_transform_approx(
borderValue=value,
)
return remap_fn(img)


def bboxes_vflip(bboxes, rows, cols):  # skipcq: PYL-W0613
    """Mirror bounding boxes vertically, i.e. around the x-axis. (numpy version of bbox_vflip)

    Args:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.
        rows (int): Image rows. Unused.
        cols (int): Image cols. Unused.

    Returns:
        bboxes (np.ndarray): A new array of the same shape; the input is left unmodified.

    Raises:
        ValueError: If `bboxes` is not an `np.ndarray`.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    flipped = bboxes.copy()
    # The new y_min is 1 - old y_max and the new y_max is 1 - old y_min; x stays as it is.
    flipped[:, [1, 3]] = 1 - bboxes[:, [3, 1]]
    return flipped


def bboxes_hflip(bboxes, rows, cols):  # skipcq: PYL-W0613
    """Mirror bounding boxes horizontally, i.e. around the y-axis. (numpy version of bbox_hflip)

    Args:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.
        rows (int): Image rows. Unused.
        cols (int): Image cols. Unused.

    Returns:
        bboxes (np.ndarray): A new array of the same shape; the input is left unmodified.

    Raises:
        ValueError: If `bboxes` is not an `np.ndarray`.
    """
    if not isinstance(bboxes, np.ndarray):
        raise ValueError("bboxes should be np.ndarray")

    flipped = bboxes.copy()
    # The new x_min is 1 - old x_max and the new x_max is 1 - old x_min; y stays as it is.
    flipped[:, [0, 2]] = 1 - bboxes[:, [2, 0]]
    return flipped


def bboxes_flip(bboxes, d, rows, cols):
    """Flip bounding boxes vertically, horizontally or both, as selected by `d`.

    Args:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.
        d (int): Flip selector: 0 flips vertically, 1 flips horizontally,
            -1 flips horizontally and then vertically.
        rows (int): Image rows.
        cols (int): Image cols.

    Returns:
        bboxes (np.ndarray): A two dimensional ndarray. Each row is `x_min, y_min, x_max, y_max` or
            `x_min, y_min, x_max, y_max, label_index`.

    Raises:
        ValueError: if value of `d` is not -1, 0 or 1.
    """
    if d == 0:
        return bboxes_vflip(bboxes, rows, cols)
    if d == 1:
        return bboxes_hflip(bboxes, rows, cols)
    if d == -1:
        horizontally_flipped = bboxes_hflip(bboxes, rows, cols)
        return bboxes_vflip(horizontally_flipped, rows, cols)
    raise ValueError("Invalid d value {}. Valid values are -1, 0 and 1".format(d))

0 comments on commit 8703fe9

Please sign in to comment.