Skip to content

Commit

Permalink
efficient multilevel MC
Browse files Browse the repository at this point in the history
  • Loading branch information
piers-hinds committed May 2, 2023
1 parent ce70818 commit ce6ab46
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 31 deletions.
20 changes: 12 additions & 8 deletions sde_mc/mlmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,18 @@ def mc_multilevel(trials, levels, solver, payoff, discounter, bs=None):
# First level
run_sum, run_sum_sq = 0, 0
trials_remaining = trials[0]
original_steps = solver.num_steps
solver.num_steps = levels[0]
while trials_remaining > 0:
next_batch_size = min(trials_remaining, bs[0])
solver.num_steps = levels[0]
paths, _ = solver.solve(bs=next_batch_size)
terminals = payoff(paths[:, solver.num_steps, :]) * discounter(solver.time_interval)
paths, _ = solver.solve(bs=next_batch_size, low_storage=True)
terminals = payoff(paths[:, -1, :]) * discounter(solver.time_interval)
run_sum += terminals.sum()
run_sum_sq += (terminals * terminals).sum()
trials_remaining -= next_batch_size

exps[0], vars[0] = mc_estimates(run_sum, run_sum_sq, trials[0])
solver.num_steps = original_steps

# Other levels
pairs = [(levels[i + 1], levels[i]) for i in range(0, len(levels) - 1)]
Expand All @@ -75,16 +77,14 @@ def mc_multilevel(trials, levels, solver, payoff, discounter, bs=None):
def get_optimal_trials(trials, levels, epsilon, solver, payoff, discounter):
"""Finds the optimal number of trials at each level for the MLMC method (for a given tolerance)"""

original_num_steps = solver.num_steps
vars = torch.zeros(len(levels))
pairs = [(levels[i + 1], levels[i]) for i in range(0, len(levels) - 1)]
step_sizes = solver.time_interval / torch.tensor(levels)
solver.num_steps = levels[0]
paths, _ = solver.solve(bs=trials)
discounted_payoffs = payoff(paths[:, solver.num_steps, :]) * discounter(solver.time_interval)
paths, _ = solver.solve(bs=trials, low_storage=True)
discounted_payoffs = payoff(paths[:, -1, :]) * discounter(solver.time_interval)
var = discounted_payoffs.var()
vars[0] = var
solver.num_steps = original_num_steps

for i, pair in enumerate(pairs):
(paths_fine, paths_coarse), _ = solver.multilevel_solve(trials, pair)
Expand All @@ -94,4 +94,8 @@ def get_optimal_trials(trials, levels, epsilon, solver, payoff, discounter):

sum_term = (vars / step_sizes).sqrt().sum()
optimal_trials = (1.96**2 / (epsilon * epsilon)) * (vars * step_sizes).sqrt() * sum_term
return optimal_trials.ceil().long().tolist()
return optimal_trials.ceil().long().tolist()


def mlmc_bs_from_trials(trials, levels, max_mem=5*10**8, dim=1):
return torch.minimum(max_mem / (dim * torch.tensor(levels)), trials).ceil().long()
139 changes: 117 additions & 22 deletions sde_mc/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,17 @@ class HestonSolver(HestonScheme, DiffusionSolver):


class JumpDiffusionSolver(SdeSolver):
def __init__(self, sde, time_interval, num_steps, device='cpu', seed=1):
def __init__(self, sde, time_interval, num_steps, device='cpu', seed=1, exact_jumps=False):
super(JumpDiffusionSolver, self).__init__(sde, time_interval, num_steps, device, seed)
self.max_jumps = max(int(self.time_interval * poisson.ppf(1 - 1/1e9, self.sde.jump_rate().sum())), 5)
self.max_jumps = max(int(self.time_interval * poisson.ppf(1 - 1 / 1e9, self.sde.jump_rate().sum())), 5)
self.exact_jumps = exact_jumps

@abstractmethod
def step(self, t, x, h, corr_normals):
pass

def add_jumps(self, t, x, jumps):
return x + self.sde.jumps(t, x, jumps)
def add_jumps(self, t, old_x, x, jumps):
return x + self.sde.jumps(t, old_x, jumps)

def sample_jump_times(self, size):
return torch.empty(size, device=self.device).exponential_(self.sde.jump_rate().sum()).cumsum(dim=1)
Expand All @@ -146,27 +147,33 @@ def sample_one_jump(self, size):
jumps = self.sde.sample_jumps([size, 1], self.device).repeat(1, self.sde.dim)
return jumps

def init_storage(self, bs, steps):
def init_storage(self, bs, steps, low_storage=False):
paths = torch.zeros(size=(bs, steps + 1, self.sde.dim), device=self.device)
if low_storage:
return paths, None, None, None, None
left_paths = torch.zeros_like(paths)
jump_paths = torch.zeros_like(paths)
time_paths = torch.zeros(size=(bs, steps + 1, 1), device=self.device) + self.time_interval
if self.sde.diffusion_struct == 'diag':
normals = torch.zeros(size=(bs, steps, self.sde.dim), device=self.device)
else:
normals = torch.zeros(size=(bs, steps, self.sde.dim, int(self.sde.brown_dim / self.sde.dim)), device=self.device)
normals = torch.zeros(size=(bs, steps, self.sde.dim, int(self.sde.brown_dim / self.sde.dim)),
device=self.device)
return paths, left_paths, time_paths, jump_paths, normals

def solve(self, bs=1, return_normals=False):
def solve(self, bs=1, return_normals=False, low_storage=False):
bs = int(bs)
h = torch.tensor(self.time_interval / self.num_steps, device=self.device)
x = self.sde.init_value.unsqueeze(0).repeat(bs, 1).to(self.device)
t = torch.zeros((bs, 1), device=self.device)

paths, left_paths, time_paths, jump_paths, normals = self.init_storage(bs, self.num_steps + self.max_jumps)
paths, left_paths, time_paths, jump_paths, normals = self.init_storage(bs, self.num_steps + self.max_jumps,
low_storage)
paths[:, 0] = x
left_paths[:, 0] = x
time_paths[:, 0] = t

if not low_storage:
left_paths[:, 0] = x
time_paths[:, 0] = t

jump_times = self.sample_jump_times(size=(bs, self.max_jumps, 1))
jump_idxs = torch.zeros_like(jump_times[:, 0, :]).long()
Expand All @@ -179,37 +186,125 @@ def solve(self, bs=1, return_normals=False):
# times and sizes of the next jump
next_jump_time = jump_times[torch.arange(bs), jump_idxs.squeeze(-1), :]

# time step is minimum of (mesh size, time to next jump, time to end of interval)
# time step is minimum of (prescribed maximum mesh size, time to next jump, time to end of interval)
h = torch.minimum(h, torch.maximum(self.time_interval - t, torch.tensor(0.)))
dt = torch.minimum(h, next_jump_time - t)

assert (next_jump_time >= t).all()
# step diffusion until the next time step
if self.sde.diffusion_struct == 'diag':
corr_normals = self.sample_corr_normals(x.shape + torch.Size([1]), h.unsqueeze(-1))
corr_normals = self.sample_corr_normals(x.shape + torch.Size([1]), dt.unsqueeze(-1))
else:
corr_normals = torch.stack([
self.sample_corr_normals(x.shape + torch.Size([1]), h.unsqueeze(-1)),
self.sample_corr_normals([x.shape[0], 1, 1], h.unsqueeze(-1), corr=False).repeat(1, x.shape[1])
], dim=-1)
self.sample_corr_normals(x.shape + torch.Size([1]), dt.unsqueeze(-1)),
self.sample_corr_normals([x.shape[0], 1, 1], dt.unsqueeze(-1), corr=False).repeat(1, x.shape[1])
], dim=-1)
old_x = x
x = self.step(t, x, dt, corr_normals)
normals[:, total_steps - 1] = corr_normals
left_paths[:, total_steps] = x
t += dt
time_paths[:, total_steps] = t
if not low_storage:
normals[:, total_steps - 1] = corr_normals
left_paths[:, total_steps] = x
time_paths[:, total_steps] = t

# add jumps if the next jump is now - could sample jumps here if storage issues
# add jumps if the next jump is now
next_jump_size = self.sample_one_jump(bs)
current_jumps = torch.where(torch.isclose(next_jump_time, t, atol=1e-12), next_jump_size,
torch.zeros_like(next_jump_size))
jump_paths[:, total_steps] = current_jumps
x = self.add_jumps(t, x, current_jumps)
if self.exact_jumps:
x = self.add_jumps(t, x, x, current_jumps)
else:
x = self.add_jumps(t, old_x, x, current_jumps)

# store in path
paths[:, total_steps] = x
if not low_storage:
jump_paths[:, total_steps] = current_jumps

# increment jump index if a jump has just happened
jump_idxs = torch.where(torch.isclose(next_jump_time, t, atol=1e-12), jump_idxs + 1, jump_idxs)
return paths, (normals, time_paths, left_paths, total_steps, jump_paths)
return paths[:, :total_steps + 1], (normals, time_paths, left_paths, total_steps, jump_paths)

def multilevel_solve(self, bs, levels, return_normals=False):
bs = int(bs)
fine, coarse = levels
factor = int(fine / coarse)

h_fine = torch.tensor(self.time_interval / fine, device=self.device)
h_coarse = factor * h_fine

t_fine = torch.zeros((bs, 1), device=self.device)
t_coarse = torch.zeros((bs, 1), device=self.device)

x_fine = self.sde.init_value.unsqueeze(0).repeat(bs, 1).to(self.device)
x_coarse = self.sde.init_value.unsqueeze(0).repeat(bs, 1).to(self.device)

paths_fine, _, _, _, _ = self.init_storage(bs, coarse + self.max_jumps, low_storage=True)
paths_coarse, _, _, _, _ = self.init_storage(bs, coarse + self.max_jumps, low_storage=True)
paths_fine[:, 0] = x_fine
paths_coarse[:, 0] = x_coarse

jump_times = self.sample_jump_times(size=(bs, self.max_jumps, 1))
jump_idxs = torch.zeros_like(jump_times[:, 0, :]).long()

total_steps_fine = 0
total_steps_coarse = 0
while torch.any(t_fine < self.time_interval):
run_sum_normals = 0

# times and sizes of the next jump
next_jump_time = jump_times[torch.arange(bs), jump_idxs.squeeze(-1), :]

# Do factor steps for the finer level - store the normals
for i in range(factor):
total_steps_fine += 1
# time step is minimum of (mesh size, time to next jump, time to end of interval)
h_fine = torch.minimum(h_fine, torch.maximum(self.time_interval - t_fine, torch.tensor(0.)))
dt_fine = torch.minimum(h_fine, next_jump_time - t_fine)
assert (next_jump_time >= t_fine).all()
# step diffusion until the next time step
if self.sde.diffusion_struct == 'diag':
corr_normals = self.sample_corr_normals(x_fine.shape + torch.Size([1]), dt_fine.unsqueeze(-1))
else:
corr_normals = torch.stack([
self.sample_corr_normals(x_fine.shape + torch.Size([1]), dt_fine.unsqueeze(-1)),
self.sample_corr_normals([x_fine.shape[0], 1, 1], dt_fine.unsqueeze(-1), corr=False).repeat(1,
x_fine.shape[
1])
], dim=-1)
old_x_fine = x_fine
x_fine = self.step(t_fine, x_fine, dt_fine, corr_normals)
t_fine += dt_fine
run_sum_normals += corr_normals

# Do one step on the coarser level
total_steps_coarse += 1
h_coarse = torch.minimum(h_coarse, torch.maximum(self.time_interval - t_coarse, torch.tensor(0.)))
dt_coarse = torch.minimum(h_coarse, next_jump_time - t_coarse)
old_x_coarse = x_coarse
x_coarse = self.step(t_coarse, x_coarse, dt_coarse, run_sum_normals) # check this
t_coarse += dt_coarse

# add jumps if the next jump is now - could sample jumps here if storage issues
assert torch.isclose(t_fine, t_coarse, atol=1e-12).all()
next_jump_size = self.sample_one_jump(bs)
current_jumps = torch.where(torch.isclose(next_jump_time, t_fine, atol=1e-12), next_jump_size,
torch.zeros_like(next_jump_size))
# jump_paths_fine[:, total_steps_fine] = current_jumps
if self.exact_jumps:
x_fine = self.add_jumps(t_fine, x_fine, x_fine, current_jumps)
x_coarse = self.add_jumps(t_coarse, x_coarse, x_coarse, current_jumps)
else:
x_fine = self.add_jumps(t_fine, old_x_fine, x_fine, current_jumps)
x_coarse = self.add_jumps(t_coarse, old_x_coarse, x_coarse, current_jumps)

# store in path
paths_fine[:, total_steps_coarse] = x_fine
paths_coarse[:, total_steps_coarse] = x_coarse

# increment jump index if a jump has just happened
jump_idxs = torch.where(torch.isclose(next_jump_time, t_fine, atol=1e-12), jump_idxs + 1, jump_idxs)
return (paths_fine[:, :total_steps_coarse + 1], paths_coarse[:, :total_steps_coarse + 1]), _


class JumpEulerSolver(EulerScheme, JumpDiffusionSolver):
Expand Down
2 changes: 1 addition & 1 deletion sde_mc/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3'
__version__ = '0.4'

0 comments on commit ce6ab46

Please sign in to comment.