Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One-liner APIs #74

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions tf_keras_vis/activation_maximization/input_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import tensorflow as tf
from scipy.ndimage.interpolation import rotate, zoom

from ..utils import order


class InputModifier(ABC):
"""Abstract class for defining an input modifier.
Expand Down Expand Up @@ -51,12 +53,15 @@ def __call__(self, seed_input) -> np.ndarray:
class Rotate(InputModifier):
"""An input modifier that introduces random rotation.
"""
def __init__(self, axes=(1, 2), degree=3.0) -> None:
def __init__(self, axes=(1, 2), degree=3.0, interpolation='bilinear') -> None:
"""
Args:
axes: The two axes that define the plane of rotation.
Defaults to (1, 2).
degree: The amount of rotation to apply. Defaults to 3.0.
interpolation: An integer or string. When integer, `interpolation`'s specification is
the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST
be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`.

Raises:
ValueError: When axes is not a tuple of two ints.
Expand All @@ -68,6 +73,7 @@ def __init__(self, axes=(1, 2), degree=3.0) -> None:
self.axes = axes
self.degree = float(degree)
self.random_generator = np.random.default_rng()
self.order = order(interpolation)

def __call__(self, seed_input) -> np.ndarray:
ndim = len(seed_input.shape)
Expand All @@ -80,7 +86,7 @@ def __call__(self, seed_input) -> np.ndarray:
self.random_generator.uniform(-self.degree, self.degree),
axes=self.axes,
reshape=False,
order=1,
order=self.order,
mode='reflect',
prefilter=False)
return seed_input
Expand All @@ -89,26 +95,33 @@ def __call__(self, seed_input) -> np.ndarray:
class Rotate2D(Rotate):
"""An input modifier for 2D that introduces random rotation.
"""
def __init__(self, degree=3.0) -> None:
def __init__(self, degree=3.0, interpolation='bilinear') -> None:
"""
Args:
degree: The amount of rotation to apply. Defaults to 3.0.
interpolation: An integer or string. When integer, `interpolation`'s specification is
the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST
be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`.
"""
super().__init__(axes=(1, 2), degree=degree)
super().__init__(axes=(1, 2), degree=degree, interpolation=interpolation)


class Scale(InputModifier):
"""An input modifier that introduces randam scaling.
"""
def __init__(self, low=0.9, high=1.1) -> None:
def __init__(self, low=0.9, high=1.1, interpolation='bilinear') -> None:
"""
Args:
low (float, optional): Lower boundary of the zoom factor. Defaults to 0.9.
high (float, optional): Higher boundary of the zoom factor. Defaults to 1.1.
interpolation: An integer or string. When integer, `interpolation`'s specification is
the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST
be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`.
"""
self.low = low
self.high = high
self.random_generator = np.random.default_rng()
self.order = order(interpolation)

def __call__(self, seed_input) -> np.ndarray:
ndim = len(seed_input.shape)
Expand All @@ -121,7 +134,7 @@ def __call__(self, seed_input) -> np.ndarray:
_factor = factor = self.random_generator.uniform(self.low, self.high)
factor *= np.ones(ndim - 2)
factor = (1, ) + tuple(factor) + (1, )
seed_input = zoom(seed_input, factor, order=1, mode='reflect', prefilter=False)
seed_input = zoom(seed_input, factor, order=self.order, mode='reflect', prefilter=False)
if _factor > 1.0:
indices = (self._central_crop_range(x, e) for x, e in zip(seed_input.shape, shape))
indices = (slice(start, stop) for start, stop in indices)
Expand Down
47 changes: 37 additions & 10 deletions tf_keras_vis/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections.abc import Iterable
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -39,27 +40,38 @@ def num_of_gpus() -> Tuple[int, int]:
return 0, 0


def listify(value, return_empty_list_if_none=True, convert_tuple_to_list=True) -> list:
"""Ensures that the value is a list.
def listify(value,
return_empty_list_if_none=True,
convert_tuple_to_list=True,
convert_iterable_to_list=False) -> list:
"""Ensures that `value` is a list.

If it is not a list, it creates a new list with `value` as an item.
If `value` is not a list, this function creates an new list that includes `value`.

Args:
value (object): A list or something else.
return_empty_list_if_none (bool, optional): When True (default), None you passed as `value`
will be converted to a empty list (i.e., `[]`). When False, None will be converted to
a list that has an None (i.e., `[None]`). Defaults to True.
convert_tuple_to_list (bool, optional): When True (default), a tuple you passed as `value`
will be converted to a list. When False, a tuple will be unconverted
(i.e., returning a tuple object that was passed as `value`). Defaults to True.
return_empty_list_if_none (bool, optional): When True (default), `None` you passed as
`value` will be converted to a empty list (i.e., `[]`). When False, `None` will be
converted to a list that contains an `None` (i.e., `[None]`). Defaults to True.
convert_tuple_to_list (bool, optional):When True (default), a tuple object you
passed as `value` will be converted to a list. When False, a tuple object will be
unconverted (i.e., returning a list of a tuple object). Defaults to True.
convert_iterable_to_list (bool, optional): When True (default), an iterable object you
passed as `value` will be converted to a list. When False, an iterable object will be
unconverted (i.e., returning a list of an iterable object). Defaults to False.
Returns:
list: A list. When `value` is a tuple and `convert_tuple_to_list` is False, a tuple.
list: A list
"""
if not isinstance(value, list):
if value is None and return_empty_list_if_none:
value = []
elif isinstance(value, tuple) and convert_tuple_to_list:
value = list(value)
elif isinstance(value, Iterable) and convert_iterable_to_list:
if not convert_tuple_to_list:
raise ValueError("When 'convert_tuple_to_list' option is False,"
"'convert_iterable_to_list' option should also be False.")
value = list(value)
else:
value = [value]
return value
Expand Down Expand Up @@ -132,3 +144,18 @@ def lower_precision_dtype(model):
(isinstance(layer, tf.keras.Model) and is_mixed_precision(layer)):
return layer.compute_dtype
return model.dtype # pragma: no cover


def order(value):
if isinstance(value, int):
return value
if isinstance(value, str):
value = value.lower()
if value == 'nearest':
return 0
if value == 'bilinear':
return 1
if value == 'cubic':
return 3
raise ValueError(f"{value} is not supported. "
"The value MUST be an integer or one of 'nearest', 'bilinear' or 'cubic'.")