Skip to content

Commit

Permalink
removes unneeded tf.function
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Feb 16, 2024
1 parent 8d800f0 commit 3006ed8
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions src/bernstein_flow/math/bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
#
# created : 2022-03-09 08:45:52 (Marcel Arpogaus)
# changed : 2024-02-09 17:01:00 (Marcel Arpogaus)
# changed : 2024-02-16 16:43:03 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
Expand Down Expand Up @@ -168,23 +168,20 @@ def gen_bernstein_polynomial_with_extrapolation(
theta
)

@tf.function
def bpoly_extra(x):
sample_shape = prefer_static.shape(x)
x_safe = (x > x_bounds[0]) & (x < x_bounds[1])
y = bpoly(tf.where(x_safe, x, tf.cast(0.5, theta.dtype)))
y = tf.where(x_safe, y, extra(x))
return reshape_output(batch_shape, sample_shape, y)

@tf.function
def bpoly_log_det_jacobian_extra(x):
sample_shape = prefer_static.shape(x)
x_safe = (x > x_bounds[0]) & (x < x_bounds[1])
y = tf.math.log(tf.abs(dbpoly(tf.where(x_safe, x, tf.cast(0.5, theta.dtype)))))
y = tf.where(x_safe, y, extra_log_det_jacobian(x))
return reshape_output(batch_shape, sample_shape, y)

@tf.function
def bpoly_inverse_extra(y, inverse_approx_fn):
sample_shape = prefer_static.shape(y)
y_safe = (y > y_bounds[0]) & (y < y_bounds[1])
Expand Down

2 comments on commit 3006ed8

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old Faithful

Learning Curve

Metrics

Min of loss: -0.7485576868057251

Parameter Vector

a1 = array([1.3379974], dtype=float32)
b1 = array([-0.19126141], dtype=float32)
thetas = array([-4.        , -2.180244  , -0.36048806, -0.3597944 , -0.35966533,
   -0.3596223 , -0.3595594 , -0.35950035, -0.35937324, -0.3558197 ,
    1.9186287 ,  2.9593098 ,  3.999991  ], dtype=float32)

Flow

Results

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bimodal Model

Learning Curve

Learning Curve

Metrics

loss: -0.7247246503829956
val_loss: -0.747040331363678

Results

Parameter Vector for x = 1

BernsteinFlow:
invert_chain_of_bpoly_of_scale1_of_shift1:
chain_of_bpoly_of_scale1_of_shift1:
bpoly: [-3.000137 -1.9482541 -0.8963711 -0.11302578 -0.02455433 -0.02454433
-0.02453433 -0.02452433 -0.02451433 -0.02450433 -0.02449433 -0.02448433
-0.02447432 -0.02446432 -0.02444159 9.331248 15.325797 15.325817
15.325837 ]
scale1: 0.3840827941894531
shift1: 0.5974458456039429

Flow



Bijector


Please sign in to comment.