Support segmentation masks (#87)
This PR introduces support for segmentation masks. The key principle is that certain operations are not applied to segmentation masks.

* Introduce `SemanticWrapper` and `Mask`

Co-authored-by: Johnny Chen <[email protected]>
barucden and johnnychen94 committed Aug 7, 2021
1 parent 3c62635 commit b5251ac
Showing 19 changed files with 314 additions and 14 deletions.
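
The principle stated in the commit message (some operations are skipped for masks) can be illustrated with a minimal, self-contained sketch. It does not load Augmentor: the type names mirror those in the diff, while `ToyFlip`, `ToyBrighten`, `apply`, and `run_pipeline` are hypothetical stand-ins introduced only for illustration. Random parameter sampling, which the real operations perform, is not modeled here.

```julia
# Sketch only: mirrors the dispatch idea of this commit without using Augmentor.
abstract type Operation end
abstract type ImageOperation <: Operation end
abstract type ColorOperation <: ImageOperation end

abstract type SemanticWrapper end
struct Mask{AT<:AbstractArray} <: SemanticWrapper
    img::AT
end
unwrap(m::Mask) = m.img
unwrap(A::AbstractArray) = A

# Hypothetical toy operations (not part of Augmentor).
struct ToyFlip <: ImageOperation end      # spatial: reverse the rows
struct ToyBrighten <: ColorOperation end  # color: add 1 to every element

apply(::ToyFlip, A::AbstractArray) = reverse(A, dims=1)
apply(::ToyBrighten, A::AbstractArray) = A .+ 1

# Trait-like decision: applied by default, but color operations skip masks.
shouldapply(op::Operation, sw::SemanticWrapper) = shouldapply(typeof(op), typeof(sw))
shouldapply(::Type{<:ImageOperation}, ::Type{<:SemanticWrapper}) = Val(true)
shouldapply(::Type{<:ColorOperation}, ::Type{<:Mask}) = Val(false)

# Dispatch on the Val returned by shouldapply.
apply(op::Operation, sw::SemanticWrapper) = apply(op, sw, shouldapply(op, sw))
apply(op::Operation, sw::SemanticWrapper, ::Val{false}) = sw
apply(op::Operation, sw::Mask, ::Val{true}) = Mask(apply(op, unwrap(sw)))

# Apply each operation in turn to every input in the tuple.
function run_pipeline(ops, inputs::Tuple)
    for op in ops
        inputs = map(x -> apply(op, x), inputs)
    end
    return inputs
end

img  = [1 2; 3 4]
mask = [0 1; 1 0]
out_img, out_mask = unwrap.(run_pipeline((ToyFlip(), ToyBrighten()), (img, Mask(mask))))
```

In this sketch the image is both flipped and brightened, while the mask is only flipped, because `shouldapply` returns `Val(false)` for the combination of a `ColorOperation` and a `Mask`.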
5 changes: 4 additions & 1 deletion docs/make.jl
@@ -43,7 +43,10 @@ makedocs(
],
"User's Guide" => [
"interface.md",
operations,
operations
],
"Developer's Guide" => [
"wrappers.md"
],
"Tutorials" => examples,
hide("Indices" => "indices.md"),
Binary file added docs/src/assets/segm_img.png
Binary file added docs/src/assets/segm_mask.png
55 changes: 55 additions & 0 deletions docs/src/gettingstarted.md
@@ -103,6 +103,61 @@ Input (`img`) | | Output (`img_new`)
:---------------------------:|:-:|:------------------------------:
![input](assets/isic_in.png) | → | ![output](assets/isic_out.gif)

## Segmentation example

Augmentor also provides a convenient interface for applying stochastic
augmentation to images together with their masks, which is useful for semantic
segmentation tasks. The following snippet demonstrates how to use the interface. The
image used is derived from the [ISIC archive](https://isic-archive.com/).

```julia
julia> using Augmentor

julia> img, mask = # load image and its mask

julia> pl = Either(Rotate90(), FlipX(), FlipY()) |>
Either(ColorJitter(), GaussianBlur(3))

julia> img_new, mask_new = augment(img => mask, pl)
```

```@eval
using Augmentor
using FileIO, ImageMagick, ImageCore
using Random
imgpath = joinpath("assets","segm_img.png")
maskpath = joinpath("assets","segm_mask.png")
img = load(imgpath)
mask = load(maskpath)
pl = Either(Rotate90(), FlipX(), FlipY()) |>
Either(ColorJitter(), GaussianBlur(3))
# modified from operations/assets/gif.jl
function make_gif(img, mask, pl, num_sample; random_seed=1337)
Random.seed!(random_seed)
fillvalue = oneunit(eltype(img))
frames = sym_paddedviews(
fillvalue,
hcat(img, mask),
[hcat(augment(img => mask, pl)...) for _ in 1:num_sample-1]...
)
cat(frames..., dims=3)
end
preview = make_gif(img, mask, pl, 16)[:, :, 2:end]
ImageMagick.save(joinpath("assets", "segm_test.gif"), preview; fps=2)
nothing
```

The augmented images and masks are displayed in the following animation:

![output](assets/segm_test.gif)

## Getting Help

To get help on specific functionality you can either look up the
16 changes: 16 additions & 0 deletions docs/src/wrappers.md
@@ -0,0 +1,16 @@
# Semantic Wrappers

Semantic wrappers define the meaning of an input and,
consequently, determine which operations can be applied to that input.

Each semantic wrapper is expected to implement a constructor that takes
the original object and wraps it, and the [`Augmentor.unwrap`](@ref) method,
which returns the wrapped object. That is, for a wrapper `W`, the following holds:
`obj == unwrap(W(obj))`.

To prevent name conflicts, it is suggested not to export any semantic wrappers.

```@docs
Augmentor.SemanticWrapper
Augmentor.unwrap
```
3 changes: 2 additions & 1 deletion src/Augmentor.jl
@@ -79,8 +79,9 @@ export
testpattern

include("compat.jl")
include("utils.jl")
include("types.jl")
include("wrapper.jl")
include("utils.jl")
include("operation.jl")

include("operations/channels.jl")
56 changes: 49 additions & 7 deletions src/augment.jl
@@ -1,8 +1,10 @@
"""
augment([img], pipeline) -> out
augment(img=>mask, pipeline) -> out
Apply the operations of the given `pipeline` sequentially to the
given image `img` and return the resulting image `out`.
given image `img` and return the resulting image `out`. For the
second method, see Semantic wrappers below.
```julia-repl
julia> img = testpattern();
@@ -43,22 +45,62 @@ image.
```julia
augment(FlipX())
```
## Semantic wrappers
It is possible to define more flexible augmentation pipelines by wrapping the
input in a semantic wrapper. Semantic wrappers determine the meaning of an input
and ensure that only appropriate operations are applied to that input.
Currently implemented semantic wrappers are:
- [`Augmentor.Mask`](@ref): Wraps a segmentation mask. Allows only spatial
transformations.
A convenient shorthand for this is `augment(img => mask, pipeline)`.
### Example
```jldoctest
using Augmentor
using Augmentor: unwrap, Mask
img, mask = testpattern(), testpattern()
pl = Rotate90() |> GaussianBlur(3)
aug_img, aug_mask = unwrap.(augment((img, Mask(mask)), pl))
# Equivalent usage
aug_img, aug_mask = augment(img => mask, pl)
# GaussianBlur will be skipped for our `mask`
aug_mask == augment(mask, Rotate90())
# output
true
```
"""
function augment(img, pipeline::AbstractPipeline)
augment(img, pipeline) = _plain_augment(img, pipeline)
# convenient interpretation for certain use cases
function augment((img, mask)::Pair{<:AbstractArray, <:AbstractArray}, pipeline)
img_out, mask_out = augment((img, Mask(mask)), pipeline)
return img_out => unwrap(mask_out)
end
augment(pipeline) = augment(use_testpattern(), pipeline) # TODO: deprecate this?

# plain augment that faithfully operates on the objects without convenient interpretation
function _plain_augment(img, pipeline::AbstractPipeline)
plain_array(_augment(img, pipeline))
end

function augment(img, pipeline::Union{ImmutablePipeline{1},NTuple{1,Operation}})
function _plain_augment(img, pipeline::Union{ImmutablePipeline{1},NTuple{1,Operation}})
augment(img, first(operations(pipeline)))
end

function augment(img, op::Operation)
function _plain_augment(img, op::Operation)
plain_array(applyeager(op, img))
end

function augment(op::Union{AbstractPipeline,Operation})
augment(use_testpattern(), op)
end

@inline function _augment(img, pipeline::AbstractPipeline)
_augment(img, operations(pipeline)...)
15 changes: 14 additions & 1 deletion src/operation.jl
@@ -36,6 +36,10 @@ prepareaffine(img::AbstractExtrapolation) = invwarpedview(img, toaffinemap(NoOp(
@inline prepareaffine(img::SubArray{T,N,<:InvWarpedView}) where {T,N} = img
@inline prepareaffine(img::InvWarpedView) = img
prepareaffine(imgs::Tuple) = map(prepareaffine, imgs)
function prepareaffine(sw::SemanticWrapper)
T = basetype(sw)
return T(prepareaffine(unwrap(sw)))
end

# currently unused
@inline preparelazy(img) = img
@@ -56,9 +60,18 @@ for FUN in (:applyeager, :applylazy, :applypermute,
param = randparam(op, imgs)
map(img -> ($FUN)(op, img, param), imgs)
end
@inline function ($FUN)(op::Operation, img::AbstractArray)
function ($FUN)(op::Operation, img::Union{AbstractArray,
SemanticWrapper})
($FUN)(op, img, randparam(op, img))
end

# Semantic wrapper support
@inline $FUN(op::Operation, sw::SemanticWrapper, param) = $FUN(op, sw, param, shouldapply(op, sw))
@inline $FUN(op::Operation, sw::SemanticWrapper, param, ::Val{false}) = sw
function $FUN(op::Operation, sw::SemanticWrapper, param, ::Val{true})
T = basetype(sw)
return T(($FUN)(op, unwrap(sw), param))
end
end
end

4 changes: 2 additions & 2 deletions src/operations/blur.jl
@@ -1,7 +1,7 @@
import ImageFiltering: imfilter, KernelFactors.gaussian

"""
GaussianBlur <: ImageOperation
GaussianBlur <: ColorOperation
Description
--------------
@@ -36,7 +36,7 @@ augment(img, GaussianBlur(3, 1.0))
augment(img, GaussianBlur(3:2:7, 1.0:0.1:2.0))
```
"""
struct GaussianBlur{K <: AbstractVector, S <: AbstractVector} <: ImageOperation
struct GaussianBlur{K <: AbstractVector, S <: AbstractVector} <: ColorOperation
k::K
σ::S

4 changes: 2 additions & 2 deletions src/operations/color.jl
@@ -2,7 +2,7 @@ import ImageCore: clamp01, gamutmax
import Statistics: mean

"""
ColorJitter <: ImageOperation
ColorJitter <: ColorOperation
Description
--------------
@@ -42,7 +42,7 @@ augment(img, ColorJitter(1.2, [0.5, 0.8]))
augment(img, ColorJitter(0.8:0.1:2.0, 0.5:0.1:1.1))
```
"""
struct ColorJitter{A<:AbstractVector, B<:AbstractVector} <: ImageOperation
struct ColorJitter{A<:AbstractVector, B<:AbstractVector} <: ColorOperation
α::A
β::B
usemax::Bool
4 changes: 4 additions & 0 deletions src/operations/either.jl
@@ -189,6 +189,10 @@ end
applyview(op, img, randparam(op, img))
end

# Semantic wrapper support
applyeager(op::Either, sw::SemanticWrapper, idx) = applyeager(op.operations[idx], sw)
applyview(op::Either, sw::SemanticWrapper, idx) = applyview(op.operations[idx], sw)

# Sample a random operation and pass the function call along.
# Note: "applyaffine" needs to map to "applyaffine_common" for
# type stability, because otherwise the concrete type of the
1 change: 1 addition & 0 deletions src/types.jl
@@ -1,5 +1,6 @@
abstract type Operation end
abstract type ImageOperation <: Operation end
abstract type AffineOperation <: ImageOperation end
abstract type ColorOperation <: ImageOperation end
abstract type Pipeline end
const AbstractPipeline = Union{Pipeline,Tuple{Vararg{Operation}}}
7 changes: 7 additions & 0 deletions src/utils.jl
@@ -44,6 +44,7 @@ See also: [`plain_array`](@ref), [`plain_axes`](@ref)
@inline contiguous(A::SArray) = A
@inline contiguous(A::MArray) = A
@inline contiguous(A::AbstractArray) = match_idx(collect(A), axes(A))
@inline contiguous(A::Mask) = Mask(contiguous(unwrap(A)))
@inline contiguous(A::Tuple) = map(contiguous, A)

# --------------------------------------------------------------------
@@ -53,6 +54,7 @@ See also: [`plain_array`](@ref), [`plain_axes`](@ref)
@inline _plain_array(A::SArray) = A
@inline _plain_array(A::MArray) = A
@inline _plain_array(A::AbstractArray) = collect(A)
@inline _plain_array(A::Mask) = Mask(_plain_array(unwrap(A)))
@inline _plain_array(A::Tuple) = map(_plain_array, A)

"""
@@ -202,3 +204,8 @@ function _2dborder!(A::AbstractArray{T,3}, val::T) where T
end
A
end

# This is expected to be added to Julia (maybe under a different name)
# Follow https://github.com/JuliaLang/julia/issues/35543 for progress
basetype(T::Type) = Base.typename(T).wrapper
basetype(T) = basetype(typeof(T))
35 changes: 35 additions & 0 deletions src/wrapper.jl
@@ -0,0 +1,35 @@
"""
A SemanticWrapper determines the semantics of the data that it wraps.
Any subtype needs to implement the function `unwrap(wrapper)`, which returns the
wrapped data.
"""
abstract type SemanticWrapper end

"""
Mask wraps a segmentation mask.
"""
struct Mask{AT<:AbstractArray} <: SemanticWrapper
img::AT
end

"""
unwrap(sw::SemanticWrapper)
Returns the original object.
"""
unwrap(m::Mask) = m.img

"""
shouldapply(op, wrapper)
shouldapply(typeof(op), typeof(wrapper))
Determines if operation `op` should be applied to semantic wrapper `wrapper`.
"""
shouldapply(op::Operation, what::SemanticWrapper) = shouldapply(typeof(op), typeof(what))
shouldapply(::Type{<:ImageOperation}, ::Type{<:SemanticWrapper}) = Val(true)
shouldapply(::Type{<:ColorOperation}, ::Type{<:Mask}) = Val(false)
# By default any operation is applicable to any semantic wrapper. Add new
# methods to this function to define exceptions.

# Allows doing `unwrap.(augment((img, Mask(img2)), pl))`
unwrap(A::AbstractArray) = A
2 changes: 2 additions & 0 deletions test/operations/tst_color.jl
@@ -1,4 +1,6 @@
@testset "ColorJitter" begin
@test ColorJitter <: Augmentor.ImageOperation

@testset "constructor" begin
@test_throws ArgumentError ColorJitter(1., 3.0:2.0)
@test_throws ArgumentError ColorJitter(5.0:3.0, 1.)
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -37,6 +37,7 @@ rgb_rect = rand(RGB{N0f8}, 2, 3)
tests = [
"tst_compat.jl",
"tst_utils.jl",
"tst_wrapper.jl",
"operations/tst_channels.jl",
"operations/tst_dims.jl",
"operations/tst_convert.jl",