diff --git a/src/bernstein_flow/bijectors/bernstein.py b/src/bernstein_flow/bijectors/bernstein.py index 903c383..4c2906e 100644 --- a/src/bernstein_flow/bijectors/bernstein.py +++ b/src/bernstein_flow/bijectors/bernstein.py @@ -60,7 +60,7 @@ class BernsteinBijector(tfb.AutoCompositeTensorBijector): def __init__( self, thetas: tf.Tensor, - extrapolation: str = False, + extrapolation: str = True, name: str = "bernstein_bijector", **kwds, ) -> None: