Skip to content

Commit

Permalink
refactor: plot flow now allows to skip certain bijectors
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Jun 13, 2024
1 parent 1736fcd commit c473649
Showing 1 changed file with 92 additions and 49 deletions.
141 changes: 92 additions & 49 deletions src/bernstein_flow/util/visualization/plot_flow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Convenience Function to plot a normalizing flow."""
# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*-
# AUTHOR INFORMATION ##########################################################
# file : plot_flow.py
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2022-06-01 15:21:22 (Marcel Arpogaus)
# changed : 2024-06-13 22:08:38 (Marcel Arpogaus)
# changed : 2024-06-13 22:57:11 (Marcel Arpogaus)
# DESCRIPTION #################################################################
#
# This project is following the PEP8 style guide:
Expand Down Expand Up @@ -47,7 +48,7 @@
tf.random.set_seed(42)


def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, str]:
def _get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, str]:
"""Get a map from bijector names to annotations.
Parameters
Expand All @@ -61,6 +62,7 @@ def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, st
-------
Dict[str, str]
Dictionary mapping bijector names to annotations.
"""
annot_map = {}
for cnt, name in enumerate(bijector_names, 1):
Expand All @@ -80,7 +82,7 @@ def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, st
}


def get_formulas(bijectors: List[tfb.Bijector]) -> str:
def _get_formulas(bijectors: List[tfb.Bijector]) -> str:
"""Generate LaTeX formulas for a list of bijectors.
Parameters
Expand All @@ -92,6 +94,7 @@ def get_formulas(bijectors: List[tfb.Bijector]) -> str:
-------
str
LaTeX formulas string.
"""
formulas = r"\begin{align*}"
for i, b in enumerate(reversed(bijectors)):
Expand All @@ -117,6 +120,7 @@ def _get_bijectors_recursive(bijector: tfb.Bijector) -> List[tfb.Bijector]:
-------
List[tfb.Bijector]
List of extracted bijectors.
"""
if hasattr(bijector, "bijectors"):
return sum([_get_bijectors_recursive(b) for b in bijector.bijectors], [])
Expand All @@ -126,7 +130,7 @@ def _get_bijectors_recursive(bijector: tfb.Bijector) -> List[tfb.Bijector]:
return [bijector]


def get_bijectors(
def _get_bijectors(
flow: tfp.distributions.TransformedDistribution,
) -> List[tfb.Bijector]:
"""Extract bijectors from a transformed distribution.
Expand All @@ -140,11 +144,12 @@ def get_bijectors(
-------
List[tfb.Bijector]
List of bijectors.
"""
return _get_bijectors_recursive(flow.bijector)


def get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]:
def _get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]:
"""Get the names of a list of bijectors.
Parameters
Expand All @@ -156,11 +161,12 @@ def get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]:
-------
List[str]
List of bijector names.
"""
return [b.name for b in reversed(bijectors)]


def split_bijector_names(
def _split_bijector_names(
bijector_names: List[str], split_bijector_name: str
) -> Tuple[List[str], List[str]]:
"""Split a list of bijector names at a given bijector name.
Expand All @@ -176,17 +182,19 @@ def split_bijector_names(
-------
Tuple[List[str], List[str]]
Tuple containing the two split lists.
"""
split_index = bijector_names.index(split_bijector_name) + 1
return bijector_names[:split_index], bijector_names[split_index:]


def get_plot_data(
def _get_plot_data(
flow: tfp.distributions.TransformedDistribution,
bijector_name: str,
n: int = 200,
z_values: np.ndarray = None,
seed: int = 1,
n: int,
z_values: np.ndarray,
seed: int,
ignore_bijectors: Tuple[str, ...],
) -> Tuple[Dict[str, Dict[str, np.ndarray]], List[str], List[str]]:
"""Generate plot data for a transformed distribution.
Expand All @@ -196,26 +204,31 @@ def get_plot_data(
Transformed distribution.
bijector_name : str
Name of the bijector to split the data at.
n : int, optional
Number of samples, by default 200
z_values : np.ndarray, optional
Predefined sample values, by default None
seed : int, optional
Random seed, by default 1
n : int
Number of samples
z_values : np.ndarray
Predefined sample values
seed : int
Random seed
ignore_bijectors : Tuple[str]
Tuple containing names of bijectors to ignore during plotting
Returns
-------
Tuple[Dict[str, Dict[str, np.ndarray]], List[str], List[str]]
Tuple containing the plot data, post-split bijector names,
and pre-split bijector names.
"""
tf.random.set_seed(seed)

chained_bijectors = get_bijectors(flow)
bijector_names = get_bijector_names(chained_bijectors)
pre_bpoly_trafos, post_bpoly_trafos = split_bijector_names(
bijectors = _get_bijectors(flow)
bijector_names = _get_bijector_names(bijectors)
pre_bpoly_trafos, post_bpoly_trafos = _split_bijector_names(
bijector_names, bijector_name
)
pre_bpoly_trafos = [t for t in pre_bpoly_trafos if t not in ignore_bijectors]
post_bpoly_trafos = [t for t in post_bpoly_trafos if t not in ignore_bijectors]

base_dist = flow.distribution

Expand All @@ -231,11 +244,12 @@ def get_plot_data(
z = z_sorted[..., None]
ildj = 0.0

for i, b in enumerate(chained_bijectors):
for i, b in enumerate(bijectors):
z = b.inverse(z).numpy()
ildj += b.forward_log_det_jacobian(z, 1)
name = b.name
plot_data[name] = {"z": z, "p": np.exp(log_probs + ildj)}
if name not in ignore_bijectors:
plot_data[name] = {"z": z, "p": np.exp(log_probs + ildj)}

after_bpoly = next(
(name for name in post_bpoly_trafos if name in plot_data), "distribution"
Expand All @@ -244,10 +258,10 @@ def get_plot_data(
"z1": plot_data[bijector_name]["z"],
"z2": plot_data[after_bpoly]["z"],
}
return plot_data, post_bpoly_trafos, pre_bpoly_trafos
return plot_data, pre_bpoly_trafos, post_bpoly_trafos


def configure_axes(a: Axes, style: str):
def _configure_axes(a: Axes, style: str):
"""Configure the axes of a plot.
Parameters
Expand All @@ -256,6 +270,7 @@ def configure_axes(a: Axes, style: str):
Axes object to configure.
style : str
Style of the axes. Can be "right", "top", or "none".
"""
if style == "right":
a.spines["top"].set_color("none")
Expand All @@ -276,13 +291,13 @@ def configure_axes(a: Axes, style: str):
a.patch.set_alpha(0.0)


def prepare_figure(
def _prepare_figure(
plot_data: Dict[str, Dict[str, np.ndarray]],
pre_bpoly_trafos: List[str],
post_bpoly_trafos: List[str],
size: int = 4,
wspace: float = 0.5,
hspace: float = 0.5,
size: int,
wspace: float,
hspace: float,
) -> Tuple[Figure, Dict[str, Axes]]:
"""Prepare the figure and axes for the plot.
Expand All @@ -295,16 +310,17 @@ def prepare_figure(
post_bpoly_trafos : List[str]
Post-split bijector names.
size : int, optional
Figure size, by default 4
wspace : float, optional
Width space between subplots, by default 0.5
hspace : float, optional
Height space between subplots, by default 0.5
Figure size
wspace : float
Width space between subplots
hspace : float
Height space between subplots
Returns
-------
Tuple[Figure, Dict[str, Axes]]
Tuple containing the figure and a dictionary mapping bijector names to axes.
"""
pre_bpoly = sum(k in pre_bpoly_trafos for k in plot_data)
post_bpoly = sum(k in post_bpoly_trafos for k in plot_data)
Expand All @@ -315,8 +331,8 @@ def prepare_figure(
2,
width_ratios=[pre_bpoly / 2, 1],
height_ratios=[1, (post_bpoly + 1) / 2],
wspace=0.2,
hspace=0.2,
wspace=wspace,
hspace=hspace,
)

gs00 = gs0[0, 0]
Expand All @@ -331,30 +347,30 @@ def prepare_figure(

axs: Dict[str, Axes] = {}
axs["bijector"] = fig.add_subplot(gs0[0, 1])
configure_axes(axs["bijector"], "none")
_configure_axes(axs["bijector"], "none")

idx = 0
for k in pre_bpoly_trafos + post_bpoly_trafos + ["distribution"]:
if k not in plot_data:
continue
if k in pre_bpoly_trafos:
axs[k] = fig.add_subplot(next(gs00_it), sharey=axs["bijector"])
configure_axes(axs[k], "right")
_configure_axes(axs[k], "right")
set_label = partial(axs[k].set_ylabel, rotation=0, ha="center")
elif k in post_bpoly_trafos + ["distribution"]:
axs[k] = fig.add_subplot(next(gs11_it), sharex=axs["bijector"])
configure_axes(axs[k], "top")
_configure_axes(axs[k], "top")
set_label = partial(axs[k].set_xlabel, va="center", ha="left")
set_label("$y$" if idx == 0 else f"$z_{{{idx}}}$")
idx += 1

axs["math"] = fig.add_subplot(gs0[1, 0])
configure_axes(axs["math"], "none")
_configure_axes(axs["math"], "none")

return fig, axs


def plot_data_to_axes(
def _plot_data_to_axes(
axs: Dict[str, Axes],
plot_data: Dict[str, Dict[str, np.ndarray]],
pre_bpoly_trafos: List[str],
Expand All @@ -372,6 +388,7 @@ def plot_data_to_axes(
Pre-split bijector names.
post_bpoly_trafos : List[str]
Post-split bijector names.
"""
scatter_kwds = dict(c="orange", alpha=0.2, s=8)
cpd_label = "(transformed) distribution"
Expand Down Expand Up @@ -404,7 +421,7 @@ def plot_data_to_axes(
ax.scatter(v["z2"], v["z1"], c="orange", s=4)


def add_annot_to_axes(
def _add_annot_to_axes(
axs: Dict[str, Axes],
plot_data: Dict[str, Dict[str, np.ndarray]],
pre_bpoly_trafos: List[str],
Expand Down Expand Up @@ -447,6 +464,7 @@ def add_annot_to_axes(
by default dict(arrowstyle="-|>", shrinkA=10, shrinkB=10, color="gray")
usetex : bool, optional
Whether to use LaTeX for text rendering, by default True
"""
xyA = None
axA = None
Expand Down Expand Up @@ -563,8 +581,12 @@ def plot_flow(
bijector_name: str = "bernstein_bijector",
n: int = 500,
z_values: np.ndarray = None,
seed: int = 1,
size: float = 1.5,
wspace: float = 0.5,
hspace: float = 0.5,
usetex: bool = True,
ignore_bijectors: Tuple[str, ...] = (),
**kwds,
) -> Figure:
"""Plot a transformed distribution (flow).
Expand All @@ -579,10 +601,18 @@ def plot_flow(
Number of samples, by default 500
z_values : np.ndarray, optional
Predefined sample values, by default None
seed : int, optional
Random seed, by default 1
size : float, optional
Figure size scaling factor, by default 1.5
wspace : float, optional
Width space between subplots, by default 0.5
hspace : float, optional
Height space between subplots, by default 0.5
usetex : bool, optional
Whether to use LaTeX for text rendering, by default True
ignore_bijectors : Tuple[str], optional
Tuple containing names of bijectors to ignore during plotting, by default ()
**kwds : optional
Additional keyword arguments passed to add_annot_to_axes.
Expand All @@ -595,29 +625,42 @@ def plot_flow(
------
AssertionError
If the flow is not unimodal (batch shape is not [] or [1]).
"""
if usetex:
plt.rcParams.update(
{"text.latex.preamble": r"\usepackage{amsmath}", "text.usetex": True}
)
assert flow.batch_shape in ([], [1]), "Only unimodal distributions supported"
plot_data, post_bpoly_trafos, pre_bpoly_trafos = get_plot_data(
flow, bijector_name=bijector_name, n=n, z_values=z_values
plot_data, pre_bpoly_trafos, post_bpoly_trafos = _get_plot_data(
flow,
bijector_name=bijector_name,
n=n,
z_values=z_values,
seed=seed,
ignore_bijectors=ignore_bijectors,
)
fig, axs = _prepare_figure(
plot_data,
pre_bpoly_trafos,
post_bpoly_trafos,
size=size,
wspace=wspace,
hspace=hspace,
)
fig, axs = prepare_figure(plot_data, pre_bpoly_trafos, post_bpoly_trafos, size=size)
plot_data_to_axes(axs, plot_data, pre_bpoly_trafos, post_bpoly_trafos)
bijectors = get_bijectors(flow)
bijector_names = get_bijector_names(bijectors)
_plot_data_to_axes(axs, plot_data, pre_bpoly_trafos, post_bpoly_trafos)
bijectors = _get_bijectors(flow)
bijector_names = pre_bpoly_trafos + post_bpoly_trafos
add_annot_to_axes_kwds = {
**dict(
bijector_name=bijector_name,
annot_map=get_annot_map(bijector_names, bijector_name),
formulas=get_formulas(bijectors) if usetex else None,
annot_map=_get_annot_map(bijector_names, bijector_name),
formulas=_get_formulas(bijectors) if usetex else None,
usetex=usetex,
),
**kwds,
}
add_annot_to_axes(
_add_annot_to_axes(
axs,
plot_data,
pre_bpoly_trafos,
Expand Down

0 comments on commit c473649

Please sign in to comment.