
newton_solver halts sampling #2985

Open
DemetriPananos opened this issue Dec 11, 2023 · 14 comments

@DemetriPananos

DemetriPananos commented Dec 11, 2023

Description

I've been trying to implement a simple sampler in the generated quantities block using the inverse CDF method. First, I draw a uniform random variable on (0, 1) then use newton_solver to sample an x. The method seems to work well for a few iterations, but then sampling completely halts. @rok-cesnovar has narrowed this down to cases when the uniform random variable is close to 0.
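Below is a minimal C++ sketch of the inverse CDF method described above. It assumes an Exponential(rate) target, chosen only because its quantile has the closed form -log1p(-u) / rate, which makes the Newton result easy to check. It also uses a plain, unguarded Newton iteration. `exp_cdf`, `exp_pdf`, `newton_inverse_cdf`, and `sample_exponential` are hypothetical names for illustration, not Stan functions.

```cpp
#include <cmath>
#include <random>
#include <vector>

// Exponential(rate) CDF and density. These formulas stay finite for x < 0,
// so a Newton step that overshoots below zero is harmless for this target.
double exp_cdf(double x, double rate) { return 1.0 - std::exp(-rate * x); }
double exp_pdf(double x, double rate) { return rate * std::exp(-rate * x); }

// Plain, unguarded Newton iteration on F(x) - u = 0, starting from x0.
// Stops when the step is tiny or after max_steps iterations.
double newton_inverse_cdf(double u, double rate, double x0,
                          int max_steps = 200, double tol = 1e-12) {
  double x = x0;
  for (int i = 0; i < max_steps; ++i) {
    const double step = (exp_cdf(x, rate) - u) / exp_pdf(x, rate);
    x -= step;
    if (std::fabs(step) < tol) break;
  }
  return x;
}

// Inverse CDF sampling: draw u ~ Uniform(0, 1), then solve F(x) = u,
// starting each solve at the mean 1 / rate.
std::vector<double> sample_exponential(int n, double rate, unsigned seed) {
  std::mt19937 rng(seed);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  std::vector<double> draws;
  draws.reserve(n);
  for (int i = 0; i < n; ++i) {
    draws.push_back(newton_inverse_cdf(unif(rng), rate, 1.0 / rate));
  }
  return draws;
}
```

For example, `sample_exponential(100, 1.0, 2985)` returns 100 draws. This CDF formula stays finite for x < 0, so an overshoot below zero on the way to a small quantile does no harm here. A CDF built on log(x), like the Lawless CDF in this issue, does not have that property.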

Example

library(cmdstanr)
library(tidyverse)

model_code <- "
functions {
  real lawless_generalized_gamma_cdf(real x, real k, real mu, real sigma) {
    real w = (log(x) - mu) / sigma;
    return gamma_p(k, k * exp(1 / sqrt(k) * w));
  }
  vector system(vector y, real k, real mu, real sigma, real u) {
    return [lawless_generalized_gamma_cdf(y[1]| k, mu, sigma) - u]';
  }
}
data {
  real k;
  real mu;
  real sigma;
  real u;

  real scaling_step;
  real f_tol;
  int max_steps;
}
generated quantities {
  real solution = solve_newton_tol(system, [mu]', scaling_step, f_tol, max_steps, k, mu, sigma, u)[1];
}
"

model_file <- write_stan_file(model_code)
model <- cmdstan_model(model_file)


u <- 5e-5

fit <- model$sample(
  data = list(
    k = 0.96, mu = 4.84, sigma = 0.97, u = u,
    scaling_step = 10e-3, f_tol = 10e-6, max_steps = 1
  ),
  chains = 1,
  iter_sampling = 1,
  fixed_param = TRUE
)

I've also tried solving on the log scale (using the log CDF and log(u)), and the problem persists. Sampling simply halts and does not continue, with no error messages.
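Moving to the log scale does not necessarily help. The C++ sketch below shows why, using a hypothetical stand-in CDF: a log-logistic(m, s) CDF written via log(x). Like the Lawless CDF, it is NaN for x < 0. Take m = 0, s = 1, u = 5e-5, and start at the median x0 = 1. A single Newton step on g(x) = log F(x) - log u then lands near x = -17.4, and the next residual is NaN. Nothing throws. The function names are illustrative only.

```cpp
#include <cmath>

// Hypothetical stand-in for a CDF written via log(x), like the Lawless
// generalized gamma CDF in this issue: a log-logistic(m, s) CDF on the log
// scale. For x < 0, std::log(x) returns NaN, which then propagates silently.
double loglogistic_log_cdf(double x, double m, double s) {
  const double z = (std::log(x) - m) / s;
  return -std::log1p(std::exp(-z));  // log F(x) = -log(1 + exp(-z))
}

// Log-scale residual g(x) = log F(x) - log u and its derivative
// g'(x) = f(x) / F(x) = (1 - F(x)) / (s * x).
double log_residual(double x, double m, double s, double u) {
  return loglogistic_log_cdf(x, m, s) - std::log(u);
}
double log_residual_deriv(double x, double m, double s) {
  const double F = std::exp(loglogistic_log_cdf(x, m, s));
  return (1.0 - F) / (s * x);
}

// One plain Newton step on the log-scale residual.
double newton_step_log_scale(double x, double m, double s, double u) {
  return x - log_residual(x, m, s, u) / log_residual_deriv(x, m, s);
}
```

In this toy case the overshoot, not the choice of scale, is what produces the NaN. The linear-scale residual F(x) - u also jumps below zero from the same start.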


Current Version:

v4.7.0

@DemetriPananos
Author

We also tried this with cmdstan directly, so we don't think this is an issue with the interface.

@DemetriPananos
Author

It seems R's optim can't solve this either, which is fine, but perhaps Stan should throw an error or a warning. Halting is not great behaviour.

@rok-cesnovar
Member

@charlesm93 would you maybe have any ideas on why the solver gets stuck in this case and doesn't error?

@spinkney
Collaborator

@DemetriPananos do you know what the solution is supposed to be here?

@DemetriPananos
Author

@spinkney No, I don't know the solution. I'm not so concerned that Stan can't solve it (R has problems too); I'm more concerned that Stan just halts rather than raising a warning or doing something more reasonable.
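Here is a sketch of the behaviour being asked for. It assumes a simple scalar Newton solver rather than KINSOL. The solver checks the residual and derivative at every step and throws a std::domain_error when it meets a non-finite value or runs out of steps. `newton_or_throw`, `ll_cdf`, and `ll_pdf` are hypothetical. The log-logistic CDF is only a stand-in that, like the Lawless CDF, goes through log(x).

```cpp
#include <cmath>
#include <functional>
#include <stdexcept>
#include <string>

// Scalar Newton solver for r(x) = 0 that fails loudly instead of silently
// carrying NaN/Inf forward. Hypothetical helper, not Stan Math or KINSOL.
double newton_or_throw(const std::function<double(double)>& r,
                       const std::function<double(double)>& dr, double x0,
                       double f_tol = 1e-10, int max_steps = 100) {
  double x = x0;
  for (int i = 0; i < max_steps; ++i) {
    const double rx = r(x);
    if (!std::isfinite(rx)) {
      throw std::domain_error("residual is not finite at x = " +
                              std::to_string(x));
    }
    if (std::fabs(rx) <= f_tol) return x;  // converged
    const double drx = dr(x);
    if (!std::isfinite(drx) || drx == 0.0) {
      throw std::domain_error("derivative is zero or not finite at x = " +
                              std::to_string(x));
    }
    x -= rx / drx;  // plain Newton step, no damping
  }
  throw std::domain_error("no convergence after " +
                          std::to_string(max_steps) + " steps");
}

// Log-logistic(m, s) CDF and density written via log(x), so they are NaN
// for x < 0 (a hypothetical stand-in for the Lawless CDF in this issue).
double ll_cdf(double x, double m, double s) {
  return 1.0 / (1.0 + std::exp(-(std::log(x) - m) / s));
}
double ll_pdf(double x, double m, double s) {
  const double F = ll_cdf(x, m, s);
  return F * (1.0 - F) / (s * x);
}
```

With m = 0, s = 1, and x0 = 1, the solver converges to the exact quantile 3/7 for u = 0.3. For u = 5e-5 the first step overshoots to about x = -0.9998, log(x) is NaN, and the call throws instead of hanging.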

@spinkney
Collaborator

I agree. In the meantime, there's the code in #2859 (comment). Would that help you?

I'm getting an answer of 0.00714.

For quantiles 0.1, 0.25, 0.5, 0.75, 0.9 I get:

Quantile   Y value
0.1        13.8
0.25       37.2
0.5        87.9
0.75       173
0.9        282
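The values above can be cross-checked with a method that cannot wander out of the domain. F(x) = P(k, z) is increasing in the inner variable z = k * exp(w / sqrt(k)), so one can invert in z by bisection and then map back with log(x) = mu + sigma * sqrt(k) * log(z / k). The C++ sketch below does this. It assumes a textbook power series for the regularized incomplete gamma P. That series is adequate for the small-to-moderate z needed here, but it is not a general replacement for Stan's `gamma_p`. `gamma_p_series` and `lawless_gg_quantile` are hypothetical names, and this is an independent sketch, not the code linked above.

```cpp
#include <cmath>
#include <stdexcept>

// Regularized lower incomplete gamma P(a, z) via its power series
// P(a, z) = z^a e^{-z} / Gamma(a) * sum_n z^n / (a (a+1) ... (a+n)).
// Adequate for small-to-moderate z only; a sketch, not a library routine.
double gamma_p_series(double a, double z) {
  if (z <= 0.0) return 0.0;
  double term = 1.0 / a, sum = term;
  for (int n = 1; n < 10000; ++n) {
    term *= z / (a + n);
    sum += term;
    if (term < sum * 1e-16) break;
  }
  return sum * std::exp(a * std::log(z) - z - std::lgamma(a));
}

// Quantile of the Lawless generalized gamma used in this issue, whose CDF is
// F(x) = P(k, k * exp(w / sqrt(k))) with w = (log(x) - mu) / sigma.
// Solve P(k, z) = u for z by bisection (every iterate stays inside the
// bracket [lo, hi] with z >= 0), then map back to x.
double lawless_gg_quantile(double u, double k, double mu, double sigma) {
  if (!(u > 0.0 && u < 1.0)) throw std::domain_error("u must be in (0, 1)");
  double lo = 0.0, hi = 1.0;
  while (gamma_p_series(k, hi) < u) {  // grow the bracket until P(k, hi) >= u
    hi *= 2.0;
    if (hi > 1e3) throw std::domain_error("bracket too large for the series");
  }
  for (int i = 0; i < 200; ++i) {
    const double mid = 0.5 * (lo + hi);
    if (gamma_p_series(k, mid) < u) lo = mid; else hi = mid;
  }
  const double z = 0.5 * (lo + hi);
  return std::exp(mu + sigma * std::sqrt(k) * std::log(z / k));
}
```

With k = 0.96, mu = 4.84, sigma = 0.97, the call `lawless_gg_quantile(5e-5, 0.96, 4.84, 0.97)` should come out close to the 0.00714 reported above.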

@DemetriPananos
Author

We found our own workaround, but thought I should raise this with the Stan team.

@syclik
Member

syclik commented Jan 5, 2024

@DemetriPananos, thanks for reporting the issue.

EDIT: My original comment was incorrect. I was able to recreate it. Please see #2985 (comment).

I was able to recreate the issue using cmdstanr, but I think it's upstream from the Math library (maybe somewhere in Stan?).

I just created a C++ unit test to test the behavior and it stops as expected.

I added this to the end of test/unit/math/rev/functor/solve_newton_test.cpp, hopefully it looks familiar:

template <typename T1, typename T2, typename T3, typename T4>
stan::return_type_t<T1, T2, T3, T4> lawless_generalized_gamma_cdf(T1& x, T2& k, T3& mu, T4& sigma) {
  stan::return_type_t<T1, T3, T4> w = (log(x) - mu) / sigma;
  return stan::math::gamma_p(k, k * exp(1 / (sqrt(k) * w)));
}

struct issue_2985_functor {

  template <typename T1, typename T2, typename T3, typename T4, typename T5>
  inline Eigen::Matrix<stan::return_type_t<T1, T2, T3, T4, T5>, Eigen::Dynamic, 1>
  operator()(const T1& y, std::ostream* pstream__,
	     const T2& k,
	     const T3& mu, const T4& sigma,
	     const T5& u) const {
    Eigen::Matrix<stan::return_type_t<T1, T2, T3, T4, T5>, Eigen::Dynamic, 1> out(1);

    out << (lawless_generalized_gamma_cdf(y[0], k, mu, sigma) - u);
    return out;
  }
};


TEST(newton_solver_test, issue_2985) {
  double k = 0.96,
    mu = 4.84,
    sigma = 0.97,
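    u = u,  // note: self-initialization reads an indeterminate value; a later test in this thread uses u = 5e-5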
    scaling_step = 10e-3,
    f_tol = 10e-6;
  int max_steps = 1;

  Eigen::VectorXd x(1);
  x << mu; // initial guess

  EXPECT_THROW( {
      Eigen::MatrixXd out = stan::math::solve_newton_tol(issue_2985_functor(),
							 x,
							 scaling_step, f_tol,
							 max_steps,
							 &std::cout, k, mu, sigma, u);
    }, std::domain_error);
}

And it passes. (It successfully stops and throws an std::domain_error.)

The next step is to find out why it's not stopping on that exception.

@syclik
Member

syclik commented Jan 5, 2024

I'm going to try to plug in what gets generated from stanc to see if there's different behavior.

@syclik
Member

syclik commented Jan 5, 2024

The generated code, as a standalone C++ test, does end up in a loop.

Here's what the C++ now looks like. I made a minor modification to remove things from the stan::model namespace (and I've already added a couple of prints):

template <typename T0__, typename T1__, typename T2__, typename T3__,
          stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
                                                      std::is_floating_point<T0__>>,
                              stan::math::disjunction<stan::is_autodiff<T1__>,
                                                      std::is_floating_point<T1__>>,
                              stan::math::disjunction<stan::is_autodiff<T2__>,
                                                      std::is_floating_point<T2__>>,
                              stan::math::disjunction<stan::is_autodiff<T3__>,
                                                      std::is_floating_point<T3__>>>* = nullptr>
stan::promote_args_t<T0__, T1__, T2__, T3__>
lawless_generalized_gamma_cdf(const T0__& x, const T1__& k, const T2__& mu,
                              const T3__& sigma, std::ostream* pstream__);
template <typename T0__, typename T1__, typename T2__, typename T3__,
          typename T4__,
          stan::require_all_t<stan::is_col_vector<T0__>,
                              stan::is_vt_not_complex<T0__>,
                              stan::math::disjunction<stan::is_autodiff<T1__>,
                                                      std::is_floating_point<T1__>>,
                              stan::math::disjunction<stan::is_autodiff<T2__>,
                                                      std::is_floating_point<T2__>>,
                              stan::math::disjunction<stan::is_autodiff<T3__>,
                                                      std::is_floating_point<T3__>>,
                              stan::math::disjunction<stan::is_autodiff<T4__>,
                                                      std::is_floating_point<T4__>>>* = nullptr>
Eigen::Matrix<stan::promote_args_t<stan::base_type_t<T0__>, T1__, T2__, T3__,
                T4__>,-1,1>
system(const T0__& y_arg__, const T1__& k, const T2__& mu, const T3__& sigma,
       const T4__& u, std::ostream* pstream__);
struct system_variadic1_functor__ {
  template <typename T0__, typename T1__, typename T2__, typename T3__,
            typename T4__,
            stan::require_all_t<stan::is_col_vector<T0__>,
                                stan::is_vt_not_complex<T0__>,
                                stan::math::disjunction<stan::is_autodiff<T1__>,
                                                        std::is_floating_point<T1__>>,
                                stan::math::disjunction<stan::is_autodiff<T2__>,
                                                        std::is_floating_point<T2__>>,
                                stan::math::disjunction<stan::is_autodiff<T3__>,
                                                        std::is_floating_point<T3__>>,
                                stan::math::disjunction<stan::is_autodiff<T4__>,
                                                        std::is_floating_point<T4__>>>* = nullptr>
  Eigen::Matrix<stan::promote_args_t<stan::base_type_t<T0__>, T1__, T2__,
                  T3__, T4__>,-1,1>
  operator()(const T0__& y, std::ostream* pstream__, const T1__& k,
             const T2__& mu, const T3__& sigma, const T4__& u) const {
    return system(y, k, mu, sigma, u, pstream__);
  }
};
// real lawless_generalized_gamma_cdf(real, real, real, real)
template <typename T0__, typename T1__, typename T2__, typename T3__,
          stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
                                                      std::is_floating_point<T0__>>,
                              stan::math::disjunction<stan::is_autodiff<T1__>,
                                                      std::is_floating_point<T1__>>,
                              stan::math::disjunction<stan::is_autodiff<T2__>,
                                                      std::is_floating_point<T2__>>,
                              stan::math::disjunction<stan::is_autodiff<T3__>,
                                                      std::is_floating_point<T3__>>>*>
stan::promote_args_t<T0__, T1__, T2__, T3__>
lawless_generalized_gamma_cdf(const T0__& x, const T1__& k, const T2__& mu,
                              const T3__& sigma, std::ostream* pstream__) {
  using local_scalar_t__ = stan::promote_args_t<T0__, T1__, T2__, T3__>;
  int current_statement__ = 0;
  // suppress unused var warning
  (void) current_statement__;
  static constexpr bool propto__ = true;
  // suppress unused var warning
  (void) propto__;
  local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
  // suppress unused var warning
  (void) DUMMY_VAR__;
  //try {
    local_scalar_t__ w = DUMMY_VAR__;
    current_statement__ = 9;
    std::cout << "(x, mu, sigma) = " << "(" << x << ", " << mu << ", " << sigma << ")" << std::endl;
    w = ((stan::math::log(x) - mu) / sigma);
    current_statement__ = 10;
    if (pstream__) {
      stan::math::stan_print(pstream__, "w = ");
      stan::math::stan_print(pstream__, w);
      *(pstream__) << std::endl;
    }
    current_statement__ = 11;
    return stan::math::gamma_p(k, (k *
             stan::math::exp(((1 / stan::math::sqrt(k)) * w))));
    //} catch (const std::exception& e) {
    //stan::lang::rethrow_located(e, locations_array__[current_statement__]);
    //}
}
// vector system(vector, real, real, real, real)
template <typename T0__, typename T1__, typename T2__, typename T3__,
          typename T4__,
          stan::require_all_t<stan::is_col_vector<T0__>,
                              stan::is_vt_not_complex<T0__>,
                              stan::math::disjunction<stan::is_autodiff<T1__>,
                                                      std::is_floating_point<T1__>>,
                              stan::math::disjunction<stan::is_autodiff<T2__>,
                                                      std::is_floating_point<T2__>>,
                              stan::math::disjunction<stan::is_autodiff<T3__>,
                                                      std::is_floating_point<T3__>>,
                              stan::math::disjunction<stan::is_autodiff<T4__>,
                                                      std::is_floating_point<T4__>>>*>
Eigen::Matrix<stan::promote_args_t<stan::base_type_t<T0__>, T1__, T2__, T3__,
                T4__>,-1,1>
system(const T0__& y_arg__, const T1__& k, const T2__& mu, const T3__& sigma,
       const T4__& u, std::ostream* pstream__) {
  using local_scalar_t__ = stan::promote_args_t<stan::base_type_t<T0__>,
                             T1__, T2__, T3__, T4__>;
  int current_statement__ = 0;
  // suppress unused var warning
  (void) current_statement__;
  const auto& y = stan::math::to_ref(y_arg__);
  static constexpr bool propto__ = true;
  // suppress unused var warning
  (void) propto__;
  local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
  // suppress unused var warning
  (void) DUMMY_VAR__;
  //try {
  current_statement__ = 13;
  return (Eigen::Matrix<local_scalar_t__,-1,1>(1) <<
	  (lawless_generalized_gamma_cdf(y[0],
					 //stan::model::rvalue(
					 //y, "y",
					 //stan::model::index_uni(1)),
					 k, mu,
					 sigma, pstream__)
	   - u)).finished();
    //} catch (const std::exception& e) {
    //stan::lang::rethrow_located(e, locations_array__[current_statement__]);
    //}
}

TEST(newton_solver_test, issue_2985) {
  double k = 0.96,
    mu = 4.84,
    sigma = 0.97,
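    u = u,  // note: self-initialization reads an indeterminate value; a later test in this thread uses u = 5e-5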
    scaling_step = 10e-3,
    f_tol = 10e-6;
  int max_steps = 1;
  
  double solution = std::numeric_limits<double>::quiet_NaN();
  solution = 
    stan::math::solve_newton_tol(system_variadic1_functor__(),
				 (Eigen::Matrix<double,-1,1>(1) << mu).finished(),
				 scaling_step, f_tol, max_steps, &std::cout, k, mu, sigma,
				 u)[0];
}

@syclik
Member

syclik commented Jan 5, 2024

And I found it. The x gets passed in as NaN:

(x, mu, sigma) = (nan, 4.84, 0.97)

For some reason, this doesn't trigger an exception.
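One general reason a NaN can slip through without an exception is that IEEE 754 arithmetic propagates quiet NaNs without trapping, and every ordered comparison involving NaN is false. Whether a NaN makes an iterative solver stop early, loop until its iteration cap, or never stop therefore depends on how its termination test is written. The sketch below is illustrative only and does not claim to describe KINSOL's internals. All function names are hypothetical.

```cpp
#include <cmath>
#include <limits>

// Two ways of writing "has the residual norm reached the tolerance?".
// They agree for ordinary numbers but disagree for NaN, because every
// ordered comparison with NaN evaluates to false.
bool converged_if_small(double norm, double tol) { return norm <= tol; }
bool converged_if_not_large(double norm, double tol) { return !(norm > tol); }

// An explicit check that treats a non-finite norm as a hard failure.
bool norm_is_usable(double norm) { return std::isfinite(norm); }

// Counts how many iterations a "repeat until converged" loop runs when the
// residual norm is stuck at `norm`; only the max_iter cap can stop it.
int iterations_until_converged(double norm, double tol, int max_iter) {
  int i = 0;
  while (!converged_if_small(norm, tol) && i < max_iter) ++i;
  return i;
}
```

Depending on which form a solver uses, a NaN residual can look like convergence or like a stall that runs to the iteration cap (or forever, if there is no cap). An explicit `std::isfinite` check avoids relying on either.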

@syclik
Member

syclik commented Jan 5, 2024

Update: I made an error in my original test. I accidentally grouped terms, writing 1 / (sqrt(k) * w) where the model has 1 / sqrt(k) * w. With that fixed, it does go into an infinite loop!

This is due to something going to NaN. Updated test:

template <typename T1, typename T2, typename T3, typename T4>
stan::return_type_t<T1, T2, T3, T4> lawless_generalized_gamma_cdf(T1& x, T2& k, T3& mu, T4& sigma) {
  stan::return_type_t<T1, T3, T4> w = (log(x) - mu) / sigma;
  std::cout << "gamma_p call = " << stan::math::gamma_p(k, k * exp(1 / sqrt(k) * w)) << std::endl;
  std::cout << "  k = " << k << std::endl
	    << "  k * exp(1 / sqrt(k) * w) = " << k * exp(1 / sqrt(k) * w) << std::endl
	    << "  1 / sqrt(k) * w = " << (1 / sqrt(k) * w) << std::endl;
  return stan::math::gamma_p(k, k * exp(1 / sqrt(k) * w));
}


struct issue_2985_functor {

  template <typename T1, typename T2, typename T3, typename T4, typename T5>
  inline Eigen::Matrix<stan::return_type_t<T1, T2, T3, T4, T5>, Eigen::Dynamic, 1>
  operator()(const T1& y, std::ostream* pstream__,
	     const T2& k,
	     const T3& mu, const T4& sigma,
	     const T5& u) const {
    std::cout << "inside solver: "
	      << "(y, k, mu, sigma, u) = ("
	      << y << ", " 
	      << k << ", " 
	      << mu << ", " 
	      << sigma << ", " 
	      << u << ")" << std::endl;
      
    Eigen::Matrix<stan::return_type_t<T1, T2, T3, T4, T5>, Eigen::Dynamic, 1> out(1);
    
    out << (lawless_generalized_gamma_cdf(y[0], k, mu, sigma) - u);
    std::cout << "output == " << out << std::endl << std::endl;
    
    return out;
  }
};


TEST(newton_solver_test, issue_2985) {
  double k = 0.96,
    mu = 4.84,
    sigma = 0.97,
    u = 5e-5,
    scaling_step = 10e-3,
    f_tol = 10e-6;
  int max_steps = 1;

  Eigen::VectorXd x(1);
  x << mu; // initial guess

  EXPECT_THROW( {
      Eigen::MatrixXd out = stan::math::solve_newton_tol(issue_2985_functor(),
							 x,
							 scaling_step, f_tol,
							 max_steps,
							 &std::cout, k, mu, sigma, u);
    }, std::domain_error);
}
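The grouping slip is easy to miss. In both Stan and C++, `*` and `/` have equal precedence and associate left to right, so the model's `1 / sqrt(k) * w` means `(1 / sqrt(k)) * w`. The tiny C++ check below shows the difference between the two readings. The function names are illustrative only.

```cpp
#include <cmath>

// As written in the Stan model: parsed left to right as (1 / sqrt(k)) * w.
double exponent_as_in_model(double k, double w) { return 1 / std::sqrt(k) * w; }

// The accidental grouping from the first version of the test.
double exponent_regrouped(double k, double w) { return 1 / (std::sqrt(k) * w); }
```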

@syclik
Member

syclik commented Jan 5, 2024

@DemetriPananos, please try this for the model and see what you get:

model_code <- "
functions {
  real lawless_generalized_gamma_cdf(real x, real k, real mu, real sigma) {
    real w = (log(x) - mu) / sigma;
    if (x < 0) {
      reject(\"x must be > 0; x = \", x, \"; this makes w = \", w);
    }
    return gamma_p(k, k * exp(1 / sqrt(k) * w));
  }
  vector system(vector y, real k, real mu, real sigma, real u) {
    return [lawless_generalized_gamma_cdf(y[1]| k, mu, sigma) - u]';
  }
}
data {
  real k;
  real mu;
  real sigma;
  real u;

  real scaling_step;
  real f_tol;
  int max_steps;
}
generated quantities {
  real solution = solve_newton_tol(system, [mu]', scaling_step, f_tol, max_steps, k, mu, sigma, u)[1];
}
"

The update is 3 lines within the lawless_generalized_gamma_cdf() function:

    if (x < 0) {
      reject(\"x must be > 0; x = \", x, \"; this makes w = \", w);
    }

When I run it, I end up with output like this:

Chain 1 Iteration: 1 / 1 [100%]  (Sampling) 
Chain 1 Exception: Exception: Exception: x must be > 0; x = -0.0213352; this makes w = nan (in '/var/folders/6_/z7l__wq90rx_3m3jgn50t1kc0000gn/T/RtmpD9Ms1j/model-976c642f4f22.stan', line 6, column 6 to column 63) (in '/var/folders/6_/z7l__wq90rx_3m3jgn50t1kc0000gn/T/RtmpD9Ms1j/model-976c642f4f22.stan', line 11, column 4 to column 68) (in '/var/folders/6_/z7l__wq90rx_3m3jgn50t1kc0000gn/T/RtmpD9Ms1j/model-976c642f4f22.stan', line 25, column 2 to column 102)
Chain 1 finished in 0.0 seconds.

@charlesm93, any thoughts on what to do? Is there a way to have KINSOL stop when it encounters NaN? Should we add more documentation to indicate to users that they need to code very, very defensively?
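On the "code defensively" question, one option besides the `reject` guard above is to wrap the system function so that any non-finite input or output throws before the solver sees it. The C++ sketch below assumes the system is a `std::function` over `std::vector<double>` rather than Stan's Eigen-based functor signature. `throw_on_nonfinite` is a hypothetical helper, not part of Stan Math.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

using Vec = std::vector<double>;
using System = std::function<Vec(const Vec&)>;

// Wraps an algebraic system so that any non-finite input or output throws a
// std::domain_error immediately, instead of handing NaN/Inf to the solver.
System throw_on_nonfinite(System f) {
  return [f](const Vec& y) {
    for (std::size_t i = 0; i < y.size(); ++i) {
      if (!std::isfinite(y[i])) {
        throw std::domain_error("non-finite input at index " + std::to_string(i));
      }
    }
    Vec out = f(y);
    for (std::size_t i = 0; i < out.size(); ++i) {
      if (!std::isfinite(out[i])) {
        throw std::domain_error("non-finite output at index " + std::to_string(i));
      }
    }
    return out;
  };
}
```

Wrapping a system like log(y[0]) this way turns the x = -0.0213352 case from the output above into an immediate exception rather than a NaN residual.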

@syclik
Member

syclik commented Jan 5, 2024

I was able to simplify the example a bit:

struct issue_2985_functor {
  template <typename T1>
  inline Eigen::Matrix<stan::return_type_t<T1>, Eigen::Dynamic, 1>
  operator()(const T1& y, std::ostream* pstream__) const {
    Eigen::Matrix<stan::return_type_t<T1>, Eigen::Dynamic, 1> out(1);
    out << log(y[0]);

    return out;
  }
};


TEST(newton_solver_test, issue_2985) {
  double scaling_step = 10e-3,
    f_tol = 10e-6;
  int max_steps = 10000000;

  Eigen::VectorXd x(1);
  x << 0; // initial guess

  Eigen::MatrixXd out = stan::math::solve_newton(issue_2985_functor(),
						 x,
						 &std::cout);
  std::cout << "out = " << out << std::endl;

  x[0] = out(0);
  std::cout << "f(out) = " << issue_2985_functor()(x, &std::cout) << std::endl;
}

Different starting values of x make this behave differently. Something slightly positive stops almost immediately. Something near 1 finds the answer at 1. When it's something higher, say 17, it also takes a long time.
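The raw Newton update for f(y) = log(y) shows where the dependence on the starting value comes from: y_next = y - log(y) / (1 / y) = y * (1 - log(y)), which is positive exactly when y < e. Starting above e (for example at 17), the first step lands at a negative y, where log(y) is NaN. Starting at 0 gives log(0) = -inf immediately. This sketch ignores whatever step control KINSOL applies, so treat it as intuition rather than a description of `solve_newton`. The function names are hypothetical.

```cpp
#include <cmath>

// One raw Newton step for f(y) = log(y), f'(y) = 1 / y:
//   y_next = y - log(y) / (1 / y) = y * (1 - log(y)).
// For y > 0, y_next > 0 exactly when y < e.
double newton_step_log(double y) { return y * (1.0 - std::log(y)); }

// Runs `steps` raw Newton steps from y0 and returns the final iterate.
double newton_log_from(double y0, int steps) {
  double y = y0;
  for (int i = 0; i < steps; ++i) y = newton_step_log(y);
  return y;
}
```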
