Commit

makes it easier to change base distribution
MArpogaus committed Oct 3, 2023
1 parent c1c222a commit ec74ee8
Showing 2 changed files with 81 additions and 20 deletions.
47 changes: 44 additions & 3 deletions src/bernstein_flow/distributions/bernstein_flow.py
@@ -161,6 +161,39 @@ def init_bijectors(
    return tfb.Invert(tfb.Chain(bijectors))


+def get_base_distribution(base_distribution, dtype, **kwds):
+    if isinstance(base_distribution, tfd.Distribution):
+        return base_distribution
+    else:
+        if base_distribution == "normal":
+            default_kwds = dict(loc=tf.convert_to_tensor(0, dtype=dtype), scale=1.0)
+            default_kwds.update(**kwds)
+            dist = tfd.Normal(**default_kwds)
+        elif base_distribution == "truncated_normal":
+            default_kwds = dict(
+                loc=tf.convert_to_tensor(0, dtype=dtype), scale=1.0, low=-4, high=4
+            )
+            default_kwds.update(**kwds)
+            dist = tfd.TruncatedNormal(**default_kwds)
+        elif base_distribution == "log_normal":
+            default_kwds = dict(loc=tf.convert_to_tensor(0, dtype=dtype), scale=1.0)
+            default_kwds.update(**kwds)
+            dist = tfd.LogNormal(**default_kwds)
+        elif base_distribution == "logistic":
+            default_kwds = dict(loc=tf.convert_to_tensor(0, dtype=dtype), scale=1.0)
+            default_kwds.update(**kwds)
+            dist = tfd.Logistic(**default_kwds)
+        elif base_distribution == "uniform":
+            default_kwds = dict(low=tf.convert_to_tensor(0, dtype=dtype), high=1.0)
+            default_kwds.update(**kwds)
+            dist = tfd.Uniform(**default_kwds)
+        elif base_distribution == "kumaraswamy":
+            dist = tfd.Kumaraswamy(**kwds)
+        else:
+            raise ValueError(f"Unsupported distribution type {base_distribution}.")
+        return dist


class BernsteinFlow(tfd.TransformedDistribution):
    """
    This class implements a `tfd.TransformedDistribution` using Bernstein
@@ -174,6 +207,7 @@ def __init__(
        b1=None,
        a2=None,
        base_distribution=None,
+        base_distribution_kwds={},
        clip_to_bernstein_domain=True,
        clip_base_distribution=False,
        bb_class=BernsteinBijector,
@@ -199,9 +233,16 @@ def __init__(
        shape = prefer_static.shape(thetas)

        if base_distribution is None:
-            base_distribution = tfd.Normal(
-                loc=tf.zeros(shape[:-1], dtype=dtype), scale=1.0
-            )
+            if bb_class == BernsteinBijector:
+                base_distribution = get_base_distribution(
+                    "truncated_normal", dtype, **base_distribution_kwds
+                )
+            else:
+                base_distribution = get_base_distribution(
+                    "normal", dtype, **base_distribution_kwds
+                )
+        else:
+            base_distribution = get_base_distribution(
+                base_distribution, dtype, **base_distribution_kwds
+            )

        bijector = init_bijectors(
            thetas,
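The new arguments make swapping the base distribution a keyword change. A minimal sketch of the resulting interface, assuming the package exposes BernsteinFlow under bernstein_flow.distributions and that the constructor's first positional argument is the Bernstein coefficient tensor (both inferred from this diff, not confirmed by it):

import tensorflow as tf

from bernstein_flow.distributions import BernsteinFlow

# Illustrative, monotonically increasing Bernstein coefficients.
thetas = tf.linspace(-4.0, 4.0, 10)

# Select the base distribution by name; the per-name defaults
# (e.g. loc=0, scale=1 for "logistic") can be overridden via
# base_distribution_kwds.
flow = BernsteinFlow(
    thetas,
    base_distribution="logistic",
    base_distribution_kwds={"scale": 2.0},
)

# Leaving base_distribution unset picks a truncated normal for the default
# BernsteinBijector and a standard normal for other bijector classes.
default_flow = BernsteinFlow(thetas)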
54 changes: 37 additions & 17 deletions test/distributions/test_bernstein_flow.py
@@ -56,15 +56,15 @@ def gen_dist(batch_shape, order=5, dtype=tf.float32, seed=1, **kwds):


class BernsteinFlowTest(tf.test.TestCase):
-    def f(self, normal_dist, trans_dist):
+    def f(self, normal_dist, trans_dist, stay_in_domain=False):
        tf.random.set_seed(42)

        dtype = normal_dist.dtype
        for input_shape in [[1], [1, 1], [1] + normal_dist.batch_shape]:
            x = tf.random.uniform(
                shape=[10] + normal_dist.batch_shape,
-                minval=-100,
-                maxval=100,
+                minval=0 if stay_in_domain else -100,
+                maxval=1 if stay_in_domain else 100,
                dtype=dtype,
            )

@@ -86,12 +86,12 @@ def f(self, normal_dist, trans_dist):
        # check the Normalization
        self.assertAllClose(
            normal_dist.cdf(normal_dist.dtype.min),
-            trans_dist.cdf(trans_dist.dtype.min),
+            trans_dist.cdf(0.0 if stay_in_domain else trans_dist.dtype.min),
            atol=1e-3,
        )
        self.assertAllClose(
            normal_dist.cdf(normal_dist.dtype.max),
-            trans_dist.cdf(trans_dist.dtype.max),
+            trans_dist.cdf(1.0 if stay_in_domain else trans_dist.dtype.max),
            atol=1e-3,
        )

@@ -153,14 +153,13 @@ def test_dist_multi_extra(self):
    @pytest.mark.skip
    def test_log_normal(self):
        batch_shape = [16, 10]
-        log_normal = tfd.LogNormal(loc=tf.zeros(batch_shape, dtype=dtype), scale=1.0)
        normal_dist, trans_dist = gen_dist(
            batch_shape=batch_shape,
            order=10,
            dtype=dtype,
-            base_distribution=log_normal,
+            base_distribution="log_normal",
            thetas_constrain_fn=get_thetas_constrain_fn(
-                low=1e-10, high=tf.math.exp(tf.constant(4.0, dtype=dtype))
+                low=1e-12, high=tf.math.exp(tf.constant(6.0, dtype=dtype))
            ),
            scale_base_distribution=True,
        )
@@ -169,34 +168,55 @@ def test_log_normal(self):
    @pytest.mark.skip
    def test_logistic(self):
        batch_shape = [16, 10]
-        logistic = tfd.Logistic(loc=tf.zeros(batch_shape, dtype=dtype), scale=1)
        normal_dist, trans_dist = gen_dist(
            batch_shape=batch_shape,
            order=10,
            dtype=dtype,
-            base_distribution=logistic,
+            base_distribution="logistic",
            bb_class=BernsteinBijectorLinearExtrapolate,
            thetas_constrain_fn=get_thetas_constrain_fn(
-                low=-8, high=8, allow_flexible_bounds=True
+                low=-20, high=20, allow_flexible_bounds=True
            ),
-            scale_base_distribution=True,
+            scale_base_distribution=False,
+            clip_to_bernstein_domain=False,
        )
        self.f(normal_dist, trans_dist)

    @pytest.mark.skip
    def test_uniform(self):
        batch_shape = [16, 10]
-        uniform = tfd.Uniform(
-            -tf.ones(batch_shape, dtype=dtype), tf.ones(batch_shape, dtype=dtype)
-        )
        normal_dist, trans_dist = gen_dist(
            batch_shape=batch_shape,
            order=10,
            dtype=dtype,
-            base_distribution=uniform,
-            thetas_constrain_fn=get_thetas_constrain_fn(low=-1.0, high=1.0),
+            base_distribution="uniform",
+            thetas_constrain_fn=get_thetas_constrain_fn(low=0.0, high=1.0),
            clip_to_bernstein_domain=False,
            scale_base_distribution=False,
+            shift_data=False,
+            scale_data=False,
        )
-        self.f(normal_dist, trans_dist)
+        self.f(normal_dist, trans_dist, stay_in_domain=True)
+
+    @pytest.mark.skip
+    def test_kumaraswamy(self):
+        batch_shape = [16, 10]
+        normal_dist, trans_dist = gen_dist(
+            batch_shape=batch_shape,
+            order=10,
+            dtype=dtype,
+            base_distribution="kumaraswamy",
+            base_distribution_kwds={
+                "concentration1": tf.convert_to_tensor(5.0, dtype),
+                "concentration0": 2.0,
+            },
+            thetas_constrain_fn=get_thetas_constrain_fn(low=0.0, high=1.0),
+            clip_to_bernstein_domain=False,
+            scale_base_distribution=False,
+            shift_data=False,
+            scale_data=False,
+        )
+        self.f(normal_dist, trans_dist, stay_in_domain=True)

    @pytest.mark.skip
    def test_student_t(self):
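For reference, the helper introduced above can be exercised on its own. This sketch only restates the behavior visible in the first hunk; the import path is an assumption:

import tensorflow as tf
import tensorflow_probability as tfp

from bernstein_flow.distributions.bernstein_flow import get_base_distribution

tfd = tfp.distributions

# Strings resolve to TFP distributions; kwds override the per-name defaults.
dist = get_base_distribution("truncated_normal", tf.float32, low=-6.0, high=6.0)
assert isinstance(dist, tfd.TruncatedNormal)

# Ready-made tfd.Distribution instances pass through unchanged.
base = tfd.Normal(loc=0.0, scale=1.0)
assert get_base_distribution(base, tf.float32) is base

# Unknown names raise a ValueError.
try:
    get_base_distribution("cauchy", tf.float32)
except ValueError as err:
    print(err)  # Unsupported distribution type cauchy.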

2 comments on commit ec74ee8

@github-actions

Old Faithful

Learning Curve

Metrics

Min of loss: -0.7430652379989624

Parameter Vector

a1 = array([7.399017], dtype=float32)
b1 = array([-0.52731043], dtype=float32)
thetas = array([-4.        , -0.30206275, -0.29629907, -0.29544288, -0.29515794,
                -0.29501364, -0.29491523, -0.29314813, -0.26503685,  3.9999893 ],
               dtype=float32)
a2 = array([1.3188093], dtype=float32)

Results

@github-actions

Bimodal Model

Learning Curve

Metrics

loss: -0.893079400062561
val_loss: -0.937126874923706

Results

Parameter Vector for x = 1

BernsteinFlow:
  invert_chain_of_bpoly_of_scale1_of_shift1:
    chain_of_bpoly_of_scale1_of_shift1:
      bpoly: [-3.0000067e+00 -2.1729639e+00 -1.3459210e+00 -1.3354615e+00
              -1.8176794e-02 -6.7518111e-03 -6.7372560e-03 -6.7227008e-03
              -6.7081456e-03 -6.6935904e-03 -6.6790353e-03 -6.6644801e-03
              -6.6499249e-03 -6.6353697e-03 -6.6208146e-03 -6.6062594e-03
              -2.5596754e-03  6.3135982e+00  1.2629756e+01]
      scale1: 0.47053080797195435
      shift1: 0.6727997660636902

Flow

Bijector