Skip to content

Commit

Permalink
updates bimodal examples
Browse files (browse the repository at this point in the history)
  • Loading branch information
MArpogaus committed Feb 6, 2024
1 parent 679bf63 commit 150a88f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 55 deletions.
60 changes: 38 additions & 22 deletions cml/bimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2021-03-22 16:42:31 (Marcel Arpogaus)
# changed : 2022-08-31 17:27:02 (Marcel Arpogaus)
# changed : 2024-02-06 12:54:41 (Marcel Arpogaus)
# DESCRIPTION ############################################################
# ...
# LICENSE ################################################################
# ...
##########################################################################
# %% Imports
import argparse
import os
from functools import partial
Expand All @@ -21,18 +22,16 @@
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

from bernstein_flow.distributions import BernsteinFlow
from bernstein_flow.util.visualization import (
plot_chained_bijectors,
plot_flow,
plot_value_and_gradient,
plot_x_trafo,
)
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow_probability import bijectors as tfb

try:
import mlflow
Expand All @@ -41,7 +40,12 @@
except ImportError:
USE_MLFLOW = False

# %% globals
metrics_path = "metrics/bimodal"
artifacts_path = "artifacts/bimodal"


# %% functions
def print_param(b, indent=0, prefix=""):
s = " " * indent + prefix
if not isinstance(b, tfb.Bijector):
Expand Down Expand Up @@ -114,18 +118,29 @@ def bf(y_pred):
return BernsteinFlow.from_pvector(y_pred, **kwds)

def my_loss_fn(y_true, y_pred):
return -tfd.Independent(bf(y_pred)).log_prob(tf.squeeze(y_true))
dist = bf(y_pred)
return -dist.log_prob(tf.squeeze(y_true))

flow_parameter_model.compile(
optimizer="adam",
loss=my_loss_fn
# run_eagerly=True
optimizer=tf.optimizers.Adam(0.001),
loss=my_loss_fn,
# run_eagerly=True,
)
return flow_parameter_model, bf


def fit_model(train_x, train_y, val_x, val_y, batch_size=32, epochs=1000, **kwds):
lr_patience = 15
def fit_model(
model,
bf,
train_x,
train_y,
val_x,
val_y,
batch_size,
epochs,
lr_patience=15,
**kwds,
):
callbacks = [
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.1, patience=lr_patience
Expand All @@ -135,17 +150,17 @@ def fit_model(train_x, train_y, val_x, val_y, batch_size=32, epochs=1000, **kwds
),
tf.keras.callbacks.TerminateOnNaN(),
] + kwds.pop("callbacks", [])
model, bf = gen_model(**kwds)
hist = model.fit(
train_x,
train_y,
validation_data=(val_x, val_y),
epochs=epochs,
shuffle=True,
# shuffle=True,
batch_size=batch_size,
callbacks=callbacks,
**kwds,
)
return model, bf, hist
return hist


def plot_dists(model, bf, test_x, test_t, test_y):
Expand Down Expand Up @@ -173,7 +188,6 @@ def prepare_data(n=100, scale_data_to_domain=False):
# Data
train_x, train_y = gen_train_data(n=n)
val_x, val_y = gen_train_data(n=n // 10)
train_x.shape, train_y.shape, val_x.shape, val_y.shape
test_x, test_y = gen_test_data(5, 200)

if scale_data_to_domain:
Expand Down Expand Up @@ -202,7 +216,7 @@ def results(
fig = plt.figure(figsize=(16, 8))
plt.scatter(train_x, train_y, alpha=0.5, label="train")
plt.scatter(test_t, test_y, alpha=0.5, label="test")
plt.scatter(val_x, val_y, alpha=0.5, label="test")
plt.scatter(val_x, val_y, alpha=0.5, label="validate")

plt.legend()
fig.savefig(os.path.join(artifacts_path, "bm_data.png"))
Expand Down Expand Up @@ -259,8 +273,11 @@ def run(seed, params, metrics_path, artifacts_path):
)
test_x = np.unique(test_t)

# Build Model
model, bf = gen_model(**params["model_kwds"])

# Fit Model
model, bf, hist = fit_model(train_x, train_y, val_x, val_y, **params["fit_kwds"])
hist = fit_model(model, bf, train_x, train_y, val_x, val_y, **params["fit_kwds"])

if not (
np.isnan(hist.history["loss"]).any() or np.isnan(hist.history["val_loss"]).any()
Expand All @@ -284,6 +301,7 @@ def run(seed, params, metrics_path, artifacts_path):
return model, bf, hist


# %% ifmain
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -298,15 +316,13 @@ def run(seed, params, metrics_path, artifacts_path):
parser.add_argument("--seed", help="random seed", default=1, type=int)

args = parser.parse_args()
params = yaml.load(open("cml/params.yaml"), Loader=yaml.Loader)["bimodal"]
with open("cml/params.yaml") as params_file:
params = yaml.load(params_file, Loader=yaml.Loader)["bimodal"]

# Ensure Reproducibility
print("TFP Version", tfp.__version__)
print("TF Version", tf.__version__)

metrics_path = "metrics/bimodal"
artifacts_path = "artifacts/bimodal"

if not os.path.exists(metrics_path):
os.makedirs(metrics_path)

Expand Down
53 changes: 24 additions & 29 deletions cml/hp_bimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2021-05-10 17:59:17 (Marcel Arpogaus)
# changed : 2021-05-11 16:38:43 (Marcel Arpogaus)
# changed : 2024-02-05 17:31:37 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
Expand All @@ -21,14 +21,10 @@
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from bernstein_flow.activations import get_thetas_constrain_fn
from bimodal import run
from hyperopt import STATUS_FAIL, STATUS_OK, Trials, fmin, hp, tpe

from bernstein_flow.bijectors import (
BernsteinBijector,
BernsteinBijectorLinearExtrapolate,
)

if __name__ == "__main__":
experiment_name = "hp_bimodal"

Expand Down Expand Up @@ -60,35 +56,33 @@
if not os.path.exists(artifacts_path):
os.makedirs(artifacts_path)

common_fit_kwds = {
"output_shape": 20,
common_model_kwds = {
"scale_data": hp.choice("scale_data", [True, False]),
"shift_data": hp.choice("shift_data", [True, False]),
"scale_base_distribution": hp.choice("scale_base_distribution", [True, False]),
"clip_base_distribution": hp.choice("clip_base_distribution", [True, False]),
"allow_values_outside_support": hp.choice(
"allow_values_outside_support", [True, False]
"thetas_constrain_fn": get_thetas_constrain_fn(
low=-4, high=4, smooth_bounds=False, allow_flexible_bounds=False
),
}

space = {
"scale_data_to_domain": hp.choice("scale_data_to_domain", [True, False]),
"fit_kwds": hp.choice(
"bijector_class",
[
dict(bb_class=BernsteinBijector, **common_fit_kwds),
dict(
bb_class=BernsteinBijectorLinearExtrapolate,
clip_to_bernstein_domain=hp.choice(
"clip_to_bernstein_domain", [True, False]
),
**common_fit_kwds,
),
],
# hp.choice("scale_data_to_domain", [True, False]),
"scale_data_to_domain": False,
"model_kwds": dict(
clip_to_bernstein_domain=hp.choice(
"clip_to_bernstein_domain", [True, False]
),
scale_base_distribution=hp.choice("scale_base_distribution", [True, False]),
**common_model_kwds,
),
"fit_kwds": {
"batch_size": hp.choice("batch_size", [16, 32, 128, 512]),
"epochs": 1000,
"verbose": 0,
},
}

mlflow.autolog()
experiment_id = mlflow.set_experiment(experiment_name)
exp = mlflow.set_experiment(experiment_name)
if os.environ.get("MLFLOW_RUN_ID", False):
mlflow.start_run()
else:
Expand All @@ -100,20 +94,21 @@ def F(params):
params["data_points"] = 25

mlflow.start_run(
experiment_id=experiment_id, nested=mlflow.active_run() is not None
experiment_id=exp.experiment_id, nested=mlflow.active_run() is not None
)
mlflow.log_param("seed", args.seed)
mlflow.log_params(
dict(filter(lambda kw: not isinstance(kw[1], dict), params.items()))
)
mlflow.log_params(params["fit_kwds"])
mlflow.log_params(params["model_kwds"])
model, bf, hist = run(args.seed, params, metrics_path, artifacts_path)
mlflow.log_artifacts(artifacts_path)

loss = min(hist.history["val_loss"])
flow = bf(model(tf.linspace(0.0, 1.0, 10)[..., None]))
status = STATUS_OK
if np.isnan(loss).any() or np.isnan(flow.mean().numpy()).any():
if np.isnan(loss).any() or np.isnan(flow.sample(100)).any():
status = STATUS_FAIL
mlflow.end_run("FINISHED" if status == STATUS_OK else "FAILED")
return {"loss": loss, "status": status}
Expand All @@ -125,7 +120,7 @@ def F(params):
algo=tpe.suggest,
max_evals=2 if args._10sec else 50,
trials=trials,
rstate=np.random.RandomState(args.seed),
rstate=np.random.default_rng(args.seed),
)
mlflow.log_params(best)
mlflow.log_metric("best_score", min(trials.losses()))
Expand Down
7 changes: 5 additions & 2 deletions cml/params.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
bimodal:
fit_kwds:
batch_size: 32
epochs: 1000
# steps_per_epoch: 2
model_kwds:
output_shape: 20
bb_class: !!python/name:bernstein_flow.bijectors.bernstein_extra.BernsteinBijectorLinearExtrapolate
thetas_constrain_fn: !!python/object/apply:bernstein_flow.activations.get_thetas_constrain_fn
kwds:
low: -3
Expand All @@ -14,5 +17,5 @@ bimodal:
scale_base_distribution: false
scale_data: true
shift_data: true
batch_size: 128
extrapolation: true
scale_data_to_domain: false
4 changes: 2 additions & 2 deletions metrics/bimodal/bm_metrics.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
loss: -0.9482412934303284
val_loss: -0.9862567186355591
loss: -0.8797889947891235
val_loss: -0.9052213430404663

1 comment on commit 150a88f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bimodal Model

Learning Curve

(Image: learning curve plot)

Metrics

loss: -0.8556832671165466
val_loss: -0.883833646774292

Results

Parameter Vector for x = 1

BernsteinFlow:
invert_chain_of_bpoly_of_scale1_of_shift1:
chain_of_bpoly_of_scale1_of_shift1:
bpoly: [-3.008563 -1.4516283 0.10530639 0.10531639 0.10532638 0.10533638
0.10534638 0.10535638 0.10536638 0.10537639 0.10538641 0.10540179
0.10543361 0.14035165 0.74284506 1.4303505 1.4303626 2.2151763
2.9999902 ]
scale1: -0.5063235759735107
shift1: -1.1845792531967163

Flow



Bijector


Please sign in to comment.