From c4736498391f84e68e8f003a5619d448c8fa8075 Mon Sep 17 00:00:00 2001 From: Marcel Arpogaus <38564291+MArpogaus@users.noreply.github.com> Date: Thu, 13 Jun 2024 23:10:11 +0200 Subject: [PATCH] refactor: plot flow now allows skipping certain bijectors --- .../util/visualization/plot_flow.py | 141 ++++++++++++------ 1 file changed, 92 insertions(+), 49 deletions(-) diff --git a/src/bernstein_flow/util/visualization/plot_flow.py b/src/bernstein_flow/util/visualization/plot_flow.py index 72be88f..f2e43e6 100644 --- a/src/bernstein_flow/util/visualization/plot_flow.py +++ b/src/bernstein_flow/util/visualization/plot_flow.py @@ -1,10 +1,11 @@ +"""Convenience function to plot a normalizing flow.""" # -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- # AUTHOR INFORMATION ########################################################## # file : plot_flow.py # author : Marcel Arpogaus # # created : 2022-06-01 15:21:22 (Marcel Arpogaus) -# changed : 2024-06-13 22:08:38 (Marcel Arpogaus) +# changed : 2024-06-13 22:57:11 (Marcel Arpogaus) # DESCRIPTION ################################################################# # # This project is following the PEP8 style guide: @@ -47,7 +48,7 @@ tf.random.set_seed(42) -def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, str]: +def _get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, str]: """Get a map from bijector names to annotations. Parameters @@ -61,6 +62,7 @@ def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, st ------- Dict[str, str] Dictionary mapping bijector names to annotations. + """ annot_map = {} for cnt, name in enumerate(bijector_names, 1): @@ -80,7 +82,7 @@ def get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, st } -def get_formulas(bijectors: List[tfb.Bijector]) -> str: +def _get_formulas(bijectors: List[tfb.Bijector]) -> str: """Generate LaTeX formulas for a list of bijectors.
Parameters @@ -92,6 +94,7 @@ def get_formulas(bijectors: List[tfb.Bijector]) -> str: ------- str LaTeX formulas string. + """ formulas = r"\begin{align*}" for i, b in enumerate(reversed(bijectors)): @@ -117,6 +120,7 @@ def _get_bijectors_recursive(bijector: tfb.Bijector) -> List[tfb.Bijector]: ------- List[tfb.Bijector] List of extracted bijectors. + """ if hasattr(bijector, "bijectors"): return sum([_get_bijectors_recursive(b) for b in bijector.bijectors], []) @@ -126,7 +130,7 @@ def _get_bijectors_recursive(bijector: tfb.Bijector) -> List[tfb.Bijector]: return [bijector] -def get_bijectors( +def _get_bijectors( flow: tfp.distributions.TransformedDistribution, ) -> List[tfb.Bijector]: """Extract bijectors from a transformed distribution. @@ -140,11 +144,12 @@ def get_bijectors( ------- List[tfb.Bijector] List of bijectors. + """ return _get_bijectors_recursive(flow.bijector) -def get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]: +def _get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]: """Get the names of a list of bijectors. Parameters @@ -156,11 +161,12 @@ def get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]: ------- List[str] List of bijector names. + """ return [b.name for b in reversed(bijectors)] -def split_bijector_names( +def _split_bijector_names( bijector_names: List[str], split_bijector_name: str ) -> Tuple[List[str], List[str]]: """Split a list of bijector names at a given bijector name. @@ -176,17 +182,19 @@ def split_bijector_names( ------- Tuple[List[str], List[str]] Tuple containing the two split lists. 
+ """ split_index = bijector_names.index(split_bijector_name) + 1 return bijector_names[:split_index], bijector_names[split_index:] -def get_plot_data( +def _get_plot_data( flow: tfp.distributions.TransformedDistribution, bijector_name: str, - n: int = 200, - z_values: np.ndarray = None, - seed: int = 1, + n: int, + z_values: np.ndarray, + seed: int, + ignore_bijectors: Tuple[str, ...], ) -> Tuple[Dict[str, Dict[str, np.ndarray]], List[str], List[str]]: """Generate plot data for a transformed distribution. @@ -196,26 +204,31 @@ def get_plot_data( Transformed distribution. bijector_name : str Name of the bijector to split the data at. - n : int, optional - Number of samples, by default 200 - z_values : np.ndarray, optional - Predefined sample values, by default None - seed : int, optional - Random seed, by default 1 + n : int + Number of samples + z_values : np.ndarray + Predefined sample values + seed : int + Random seed + ignore_bijectors : Tuple[str, ...] + Tuple containing names of bijectors to ignore during plotting Returns ------- Tuple[Dict[str, Dict[str, np.ndarray]], List[str], List[str]] Tuple containing the plot data, post-split bijector names, and pre-split bijector names.
+ """ tf.random.set_seed(seed) - chained_bijectors = get_bijectors(flow) - bijector_names = get_bijector_names(chained_bijectors) - pre_bpoly_trafos, post_bpoly_trafos = split_bijector_names( + bijectors = _get_bijectors(flow) + bijector_names = _get_bijector_names(bijectors) + pre_bpoly_trafos, post_bpoly_trafos = _split_bijector_names( bijector_names, bijector_name ) + pre_bpoly_trafos = [t for t in pre_bpoly_trafos if t not in ignore_bijectors] + post_bpoly_trafos = [t for t in post_bpoly_trafos if t not in ignore_bijectors] base_dist = flow.distribution @@ -231,11 +244,12 @@ def get_plot_data( z = z_sorted[..., None] ildj = 0.0 - for i, b in enumerate(chained_bijectors): + for i, b in enumerate(bijectors): z = b.inverse(z).numpy() ildj += b.forward_log_det_jacobian(z, 1) name = b.name - plot_data[name] = {"z": z, "p": np.exp(log_probs + ildj)} + if name not in ignore_bijectors: + plot_data[name] = {"z": z, "p": np.exp(log_probs + ildj)} after_bpoly = next( (name for name in post_bpoly_trafos if name in plot_data), "distribution" @@ -244,10 +258,10 @@ def get_plot_data( "z1": plot_data[bijector_name]["z"], "z2": plot_data[after_bpoly]["z"], } - return plot_data, post_bpoly_trafos, pre_bpoly_trafos + return plot_data, pre_bpoly_trafos, post_bpoly_trafos -def configure_axes(a: Axes, style: str): +def _configure_axes(a: Axes, style: str): """Configure the axes of a plot. Parameters @@ -256,6 +270,7 @@ def configure_axes(a: Axes, style: str): Axes object to configure. style : str Style of the axes. Can be "right", "top", or "none". 
+ """ if style == "right": a.spines["top"].set_color("none") @@ -276,13 +291,13 @@ def configure_axes(a: Axes, style: str): a.patch.set_alpha(0.0) -def prepare_figure( +def _prepare_figure( plot_data: Dict[str, Dict[str, np.ndarray]], pre_bpoly_trafos: List[str], post_bpoly_trafos: List[str], - size: int = 4, - wspace: float = 0.5, - hspace: float = 0.5, + size: int, + wspace: float, + hspace: float, ) -> Tuple[Figure, Dict[str, Axes]]: """Prepare the figure and axes for the plot. @@ -295,16 +310,17 @@ def prepare_figure( post_bpoly_trafos : List[str] Post-split bijector names. size : int, optional - Figure size, by default 4 - wspace : float, optional - Width space between subplots, by default 0.5 - hspace : float, optional - Height space between subplots, by default 0.5 + Figure size + wspace : float + Width space between subplots + hspace : float + Height space between subplots Returns ------- Tuple[Figure, Dict[str, Axes]] Tuple containing the figure and a dictionary mapping bijector names to axes. 
+ """ pre_bpoly = sum(k in pre_bpoly_trafos for k in plot_data) post_bpoly = sum(k in post_bpoly_trafos for k in plot_data) @@ -315,8 +331,8 @@ def prepare_figure( 2, width_ratios=[pre_bpoly / 2, 1], height_ratios=[1, (post_bpoly + 1) / 2], - wspace=0.2, - hspace=0.2, + wspace=wspace, + hspace=hspace, ) gs00 = gs0[0, 0] @@ -331,7 +347,7 @@ def prepare_figure( axs: Dict[str, Axes] = {} axs["bijector"] = fig.add_subplot(gs0[0, 1]) - configure_axes(axs["bijector"], "none") + _configure_axes(axs["bijector"], "none") idx = 0 for k in pre_bpoly_trafos + post_bpoly_trafos + ["distribution"]: @@ -339,22 +355,22 @@ def prepare_figure( continue if k in pre_bpoly_trafos: axs[k] = fig.add_subplot(next(gs00_it), sharey=axs["bijector"]) - configure_axes(axs[k], "right") + _configure_axes(axs[k], "right") set_label = partial(axs[k].set_ylabel, rotation=0, ha="center") elif k in post_bpoly_trafos + ["distribution"]: axs[k] = fig.add_subplot(next(gs11_it), sharex=axs["bijector"]) - configure_axes(axs[k], "top") + _configure_axes(axs[k], "top") set_label = partial(axs[k].set_xlabel, va="center", ha="left") set_label("$y$" if idx == 0 else f"$z_{{{idx}}}$") idx += 1 axs["math"] = fig.add_subplot(gs0[1, 0]) - configure_axes(axs["math"], "none") + _configure_axes(axs["math"], "none") return fig, axs -def plot_data_to_axes( +def _plot_data_to_axes( axs: Dict[str, Axes], plot_data: Dict[str, Dict[str, np.ndarray]], pre_bpoly_trafos: List[str], @@ -372,6 +388,7 @@ def plot_data_to_axes( Pre-split bijector names. post_bpoly_trafos : List[str] Post-split bijector names. 
+ """ scatter_kwds = dict(c="orange", alpha=0.2, s=8) cpd_label = "(transformed) distribution" @@ -404,7 +421,7 @@ def plot_data_to_axes( ax.scatter(v["z2"], v["z1"], c="orange", s=4) -def add_annot_to_axes( +def _add_annot_to_axes( axs: Dict[str, Axes], plot_data: Dict[str, Dict[str, np.ndarray]], pre_bpoly_trafos: List[str], @@ -447,6 +464,7 @@ def add_annot_to_axes( by default dict(arrowstyle="-|>", shrinkA=10, shrinkB=10, color="gray") usetex : bool, optional Whether to use LaTeX for text rendering, by default True + """ xyA = None axA = None @@ -563,8 +581,12 @@ def plot_flow( bijector_name: str = "bernstein_bijector", n: int = 500, z_values: np.ndarray = None, + seed: int = 1, size: float = 1.5, + wspace: float = 0.5, + hspace: float = 0.5, usetex: bool = True, + ignore_bijectors: Tuple[str, ...] = (), **kwds, ) -> Figure: """Plot a transformed distribution (flow). @@ -579,10 +601,18 @@ def plot_flow( Number of samples, by default 500 z_values : np.ndarray, optional Predefined sample values, by default None + seed : int, optional + Random seed, by default 1 size : float, optional Figure size scaling factor, by default 1.5 + wspace : float, optional + Width space between subplots, by default 0.5 + hspace : float, optional + Height space between subplots, by default 0.5 usetex : bool, optional Whether to use LaTeX for text rendering, by default True + ignore_bijectors : Tuple[str, ...], optional + Tuple containing names of bijectors to ignore during plotting, by default () **kwds : optional Additional keyword arguments passed to add_annot_to_axes. @@ -595,29 +625,42 @@ def plot_flow( ------ AssertionError If the flow is not unimodal (batch shape is not [] or [1]).
+ """ if usetex: plt.rcParams.update( {"text.latex.preamble": r"\usepackage{amsmath}", "text.usetex": True} ) assert flow.batch_shape in ([], [1]), "Only unimodal distributions supported" - plot_data, post_bpoly_trafos, pre_bpoly_trafos = get_plot_data( - flow, bijector_name=bijector_name, n=n, z_values=z_values + plot_data, pre_bpoly_trafos, post_bpoly_trafos = _get_plot_data( + flow, + bijector_name=bijector_name, + n=n, + z_values=z_values, + seed=seed, + ignore_bijectors=ignore_bijectors, + ) + fig, axs = _prepare_figure( + plot_data, + pre_bpoly_trafos, + post_bpoly_trafos, + size=size, + wspace=wspace, + hspace=hspace, ) - fig, axs = prepare_figure(plot_data, pre_bpoly_trafos, post_bpoly_trafos, size=size) - plot_data_to_axes(axs, plot_data, pre_bpoly_trafos, post_bpoly_trafos) - bijectors = get_bijectors(flow) - bijector_names = get_bijector_names(bijectors) + _plot_data_to_axes(axs, plot_data, pre_bpoly_trafos, post_bpoly_trafos) + bijectors = _get_bijectors(flow) + bijector_names = pre_bpoly_trafos + post_bpoly_trafos add_annot_to_axes_kwds = { **dict( bijector_name=bijector_name, - annot_map=get_annot_map(bijector_names, bijector_name), - formulas=get_formulas(bijectors) if usetex else None, + annot_map=_get_annot_map(bijector_names, bijector_name), + formulas=_get_formulas(bijectors) if usetex else None, usetex=usetex, ), **kwds, } - add_annot_to_axes( + _add_annot_to_axes( axs, plot_data, pre_bpoly_trafos,