Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A proposal of a framework to handle multi-image augmentation (Including mosaic augmentation) #1420

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

i-aki-y
Copy link
Contributor

@i-aki-y i-aki-y commented Mar 8, 2023

I propose a new batch-based augmentation framework as a natural extension of the Compose framework.

With this feature, I want to make it easy to implement augmentations that use multiple image inputs, such as Mosaic and MixUp augmentations.

I hope you will be interested and would appreciate it if you could review this PR.

About the PR

As a natural extension of the Compose, I introduced a new batch-based compose, BatchedCompose, and associate classes.
In the BatchedCompose, we can seamlessly combine single-image and multi-image transforms with minimum constraints.
The BatchCompose supports most features provided by the Compose. And the existing transforms contained in the BatchCompose, such as OneOf and HorizontalFlip, work as expected.

To demonstrate how this framework will work, I include a complete implementation of a Mosaic augmentation in this PR.
Thanks to the new framework, the implementation and usage are much simpler and cleaner than my previous PR #1147.

This is an example:

## Define augmentation
demo1 = A.BatchCompose(
    [
        A.ForEach([A.OneOf([A.VerticalFlip(p=1), A.HorizontalFlip(p=1)], p=1),A.RandomScale(scale_limit=0.2)]),
        A.Mosaic4(out_height=250, out_width=250, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
        A.ForEach([A.RandomCrop(height=200, width=200, p=1)]),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["blabels_batch"]),
    keypoint_params=A.KeypointParams(format="xy", label_fields=["klabels_batch"]),
)

## Apply augmentation
data1 = demo1(
    image_batch=image_batch,
    bboxes_batch=bboxes_batch,
    mask_batch=mask_batch,
    keypoints_batch=kpts_batch,
    blabels_batch=blabel_batch,
    klabels_batch=klabel_batch
)

(You can see the complete code in the last section below)

The inputs and output are here:

demo1_in

demo1_out

The BatchCompose expects batched targets as inputs, so the outputs also are batched.
The ForEach is a helper container introduced in this PR. This works to bridge single-image transforms and multi-image transforms.
The Mosaic4 is an example of a multi-image transform.

Note that this example uses all standard targets (image, bboxes, mask, keypoints, and label_fields), and they work as expected.

Another demo demonstrates how powerful the framework is. (You do not need to understand the detail of the transforms).

demo2 = A.BatchCompose(
    [
        A.Repeat([
            A.Repeat([
                A.OneOf([
                    A.ForEach([
                        A.OneOf([A.VerticalFlip(p=1), A.HorizontalFlip(p=1)], p=1),
                        A.Resize(100, 100, p=1),
                    ]),
                    A.ForEach([A.RandomScale(scale_limit=(-0.5, -0.8)), A.ToSepia(p=0.5)]),
                ], p=1),
            ], n=4),
            A.Mosaic4(out_height=200, out_width=200, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
            A.ForEach([A.CoarseDropout(p=0.7, max_height=20, max_width=20), A.RandomScale(scale_limit=(-0.2, 0.0), p=1)]),
        ], n=4),
        A.ForEach([A.Rotate(limit=30, p=1)]),
        A.Mosaic4(out_height=400, out_width=400, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
        A.ForEach([
            A.RandomCrop(height=350, width=350, p=1)
        ]),        
    ]
)
image_single = skimage.data.astronaut()
data2 = demo2(image_batch=[image_single])

demo2_in

demo2_out

Note that the input is a single-image batch. The Repeat is another helper container that applies transforms n times and concatenates the output batches.
Combining the ForEach and the Repeat make this batch-based mechanism very powerful and flexible.

How to use

The user should obey the following rule to work with the BatchCompose.
I think complying with these constraints is not so hard and can cover most usecase.

  • All targets should be batch, and the names should have _batch suffixes. This rule is also applied to the label_fields and additional_targets parameters.
  • When you use the single-image transforms, you must enclose the transforms by the ForEach.
  • A multi-image transform should inherit the BatchBasedTransform.

Mosaic augmentation

This PR includes an implementation of the mosaic augmentation.
Implementation is straightforward, except that it is batch-based.
One notable feature I added is the out_batch_size parameter. This allows the user to specify the output's batch size.
Similar behavior can be achieved using a Repeat container, but each internal transform can control the behavior in more detail.
I think supporting out_batch_size can enrich the batch-base transform's flexibility.

Compatibility

The BatchCompose is implemented as a subclass of the Compose.
I have modified the Compose to add customization points but kept the behavior unchanged.

Future work

Now the BatchBasedTransform does not support the functionality of the ReplayCompose.
If this PR can be approved, I will work on this in a different PR.

Currently, Mosaic augmentation is the only example. I hope this can accelerate support for other multi-image transforms.

Full example code

Most lines are for data preparation and visualization. The only important parts are those already quoted.

import random
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import cv2
import albumentations as A
import skimage.data

## Helper funcs
def get_sample_data(char):
    IMAGE_SHAPE = (100, 100, 3)    
    image = np.full(IMAGE_SHAPE, 255, dtype=np.uint8)
    cv2.putText(
        image, 
        char, 
        (15, 85), 
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=3.5,
        color=(0, 0, 0),
        thickness=8,
        lineType=cv2.LINE_AA           
    )
    image_bin = (image < 128).astype(np.uint8)
    contours, _ = cv2.findContours(image_bin[:, :, 0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    pts = contours[0].reshape(-1, 2)
    xmin, ymin, xmax, ymax = pts[:, 0].min(), pts[:, 1].min(), pts[:, 0].max(), pts[:, 1].max()
    bboxes = [(xmin, ymin, xmax - xmin, ymax - ymin)]    
    mask = np.full_like(image, 255, dtype=np.uint8)
    cv2.fillPoly(mask, pts.reshape(1, -1, 2), color=(255, 255, 127))
    keypoints = [(x, y) for i, (x, y) in enumerate(contours[0].reshape(-1, 2).tolist()) if i % 10 == 0]
    return image, bboxes, mask, keypoints

def get_sample_batches(chars):    
    image_batch = []
    bboxes_batch = []
    mask_batch = []
    keypoints_batch = []
    blabel_batch = []
    klabel_batch = []
    for char in chars:
        image, bboxes, mask, keypoints = get_sample_data(char)
        image_batch.append(image)
        bboxes_batch.append(bboxes)
        mask_batch.append(mask)
        keypoints_batch.append(keypoints)
        blabel_batch.append(char)
        klabel_batch.append([char for p in keypoints])
    return image_batch, bboxes_batch, mask_batch, keypoints_batch, blabel_batch, klabel_batch

COLOR_MAP = {"A": "red", "B": "blue", "C": "green", "D": "brown"}
def add_bbox(ax, bbox, label):
    x_min, y_min, w, h = bbox[:4]
    pat = Rectangle(xy=(x_min, y_min), width=w, height=h, fill=False, lw=2, color=COLOR_MAP[label], alpha=0.7)
    ax.add_patch(pat)

def visualize_image_and_annotations(image, bboxes, mask, keypoints, blabels, klabels, ax):

    alpha = 0.5
    ax.imshow((alpha * image + (1 - alpha) * mask).astype(np.uint8))
    for i in range(len(bboxes)):
        add_bbox(ax, bboxes[i], blabels[i])
    pts = np.array(keypoints).reshape(-1, 2)
    c = [COLOR_MAP[l] for l in klabels]
    ax.scatter(x=pts[:, 0], y=pts[:, 1], c=c, s=20)
    h, w = image.shape[:2]
    ax.set_xlim(0, w)
    ax.set_ylim(h, 0)


## Make Reproducible
random.seed(1)
np.random.seed(1)
    
## Prepare data
image_batch, bboxes_batch, mask_batch, kpts_batch, blabel_batch, klabel_batch = get_sample_batches("ABCD")
image_single = skimage.data.astronaut()

## Define and Apply
demo1 = A.BatchCompose(
    [
        A.ForEach([A.OneOf([A.VerticalFlip(p=1), A.HorizontalFlip(p=1)], p=1), A.RandomScale(scale_limit=0.2)]),
        A.Mosaic4(out_height=250, out_width=250, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
        A.ForEach([A.RandomCrop(height=200, width=200, p=1)]),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["blabels_batch"]),
    keypoint_params=A.KeypointParams(format="xy", label_fields=["klabels_batch"]),
)
data1 = demo1(
    image_batch=image_batch,
    bboxes_batch=bboxes_batch,
    mask_batch=mask_batch,
    keypoints_batch=kpts_batch,
    blabels_batch=blabel_batch,
    klabels_batch=klabel_batch
)

demo2 = A.BatchCompose(
    [
        A.Repeat([
            A.Repeat([
                A.OneOf([
                    A.ForEach([
                        A.OneOf([A.VerticalFlip(p=1), A.HorizontalFlip(p=1)], p=1),
                        A.Resize(100, 100, p=1),
                    ]),
                    A.ForEach([A.RandomScale(scale_limit=(-0.5, -0.8)), A.ToSepia(p=0.5)]),
                ], p=1),
            ], n=4),
            A.Mosaic4(out_height=200, out_width=200, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
            A.ForEach([A.CoarseDropout(p=0.7, max_height=20, max_width=20), A.RandomScale(scale_limit=(-0.2, 0.0), p=1)]),
        ], n=4),
        A.ForEach([A.Rotate(limit=30, p=1)]),
        A.Mosaic4(out_height=400, out_width=400, replace=False, value=(127, 127, 127), out_batch_size=1, p=1),
        A.ForEach([
            A.RandomCrop(height=350, width=350, p=1)
        ]),        
    ]
)
data2 = demo2(image_batch=[image_single])

## Visualize

fig, axes = plt.subplots(1, 4, figsize=(10, 3))
axes = axes.flatten()
axes[0].set_title("demo1 input")
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
for i in range(4):
    visualize_image_and_annotations(
        image_batch[i], bboxes_batch[i], mask_batch[i], kpts_batch[i], blabel_batch[i], klabel_batch[i], axes[i]
    )
fig.savefig("./demo1_in.jpg", bbox_inches="tight")

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_title("demo1 output")
ax.set_xticks([])
ax.set_yticks([])
visualize_image_and_annotations(
    data1["image_batch"][0],
    data1["bboxes_batch"][0],
    data1["mask_batch"][0],
    data1["keypoints_batch"][0],
    data1["blabels_batch"][0],
    data1["klabels_batch"][0],        
    ax
)
fig.savefig("./demo1_out.jpg", bbox_inches="tight")

fig, ax = plt.subplots(1, 1, figsize=(3, 3))
ax.set_title("demo2 input")
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image_single)
fig.savefig("./demo2_in.jpg", bbox_inches="tight")
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_xticks([])
ax.set_yticks([])
data2 = demo2(image_batch=[image_single])
ax.set_title("demo2 output")
ax.imshow(data2["image_batch"][0])
fig.savefig("./demo2_out.jpg", bbox_inches="tight")

@mikel-brostrom
Copy link

mikel-brostrom commented Mar 10, 2023

Wow, this looks fantastic @i-aki-y. Should I adapt MixUp in my PR and make it inherit BatchBasedTransform if this is the new way of working with multi-image augmentations @Dipet?

@mikel-brostrom
Copy link

Any plans of getting this merged any time soon?

@thiagoribeirodamotta
Copy link

Any plans of getting this merged any time soon?

I'm also very interested on this matter. Are there any plans on merging this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants