Skip to content

Commit

Permalink
Update mlmc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
piers-hinds committed Apr 19, 2023
1 parent 6ef7e71 commit dddafe2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sde_mc/mlmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def mc_multilevel(trials, levels, solver, payoff, discounter, bs=None):
while trials_remaining > 0:
next_batch_size = min(trials_remaining, bs[i + 1])
(paths_fine, paths_coarse), _ = solver.multilevel_solve(next_batch_size, pair)
terminals = discounter(solver.time_interval) * (payoff(paths_fine[:, pair[0]]) -
payoff(paths_coarse[:, pair[1]]))
terminals = discounter(solver.time_interval) * (payoff(paths_fine[:, -1]) -
payoff(paths_coarse[:, -1]))
run_sum += terminals.sum()
run_sum_sq += (terminals * terminals).sum()
trials_remaining -= next_batch_size
Expand All @@ -86,8 +86,8 @@ def get_optimal_trials(trials, levels, epsilon, solver, payoff, discounter):

for i, pair in enumerate(pairs):
(paths_fine, paths_coarse), _ = solver.multilevel_solve(trials, pair)
terminals = discounter(solver.time_interval) * (payoff(paths_fine[:, pair[0]]) -
payoff(paths_coarse[:, pair[1]]))
terminals = discounter(solver.time_interval) * (payoff(paths_fine[:, -1]) -
payoff(paths_coarse[:, -1]))
vars[i + 1] = terminals.var()

sum_term = (vars / step_sizes).sqrt().sum()
Expand Down

0 comments on commit dddafe2

Please sign in to comment.