
benrhodes26/enhanced_discrete_gradient_mcmc


Enhanced gradient-based MCMC in discrete spaces

This repository contains the code used for the paper

Rhodes, B. and Gutmann, M. U. (2022). “Enhanced gradient-based MCMC in discrete spaces”. In: Transactions on Machine Learning Research

This repository is no longer actively maintained. However, you are welcome to email the lead author at [email protected] with questions about the code.

Dependencies

The env.yml file lists the required packages. You can build the environment with Mamba:

mamba env create -f env.yml

Organisation of code

We introduce three samplers in the paper (NCG, AVG and PAVG), each implemented as a Python class. The NCG class is in samplers/regular_samplers.py, whilst the (P)AVG classes are in samplers/auxiliary_samplers.py. Each class has a .step() method that implements one step of the corresponding transition operator; these methods use caching and GPU-friendly vectorisation to reduce memory usage and runtime.
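To illustrate what a .step() method with caching can look like, here is a minimal sketch in plain Python for binary variables. It is not the repository's NCG, AVG or PAVG implementation: the class name GradientFlipSamplerSketch and its constructor arguments are hypothetical, and the proposal is a swapped-in gradient-informed, coordinate-wise flip proposal (in the style of discrete Langevin samplers) with a Metropolis-Hastings correction. The caching idea is that f(x) and its gradient for the current state are stored on the object, so each step evaluates the target only at the proposed state. Vectorisation and GPU support are omitted for brevity.

```python
import math
import random


def log_sigmoid(z):
    # Numerically stable log(1 / (1 + exp(-z))).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


class GradientFlipSamplerSketch:
    """Hypothetical gradient-informed Metropolis-Hastings sampler on {0,1}^D.

    Not the repository's NCG/AVG/PAVG classes: a discrete-Langevin-style proposal used
    only to illustrate a .step() interface that caches f(x) and grad f(x).
    """

    def __init__(self, log_prob, grad_log_prob, x_init, step_size=1.0, seed=0):
        self.log_prob = log_prob
        self.grad_log_prob = grad_log_prob
        self.step_size = step_size
        self.rng = random.Random(seed)
        self.x = list(x_init)
        # Cache for the current state, reused by the next call to step().
        self.fx = log_prob(self.x)
        self.gx = grad_log_prob(self.x)

    def _flip_logits(self, x, g):
        # q(x'_i | x) is proportional to exp(0.5*g_i*(x'_i - x_i) - (x'_i - x_i)**2 / (2*step_size)),
        # so the log-odds of flipping bit i (x'_i - x_i = 1 - 2*x_i) versus keeping it are:
        return [0.5 * gi * (1 - 2 * xi) - 1.0 / (2 * self.step_size) for xi, gi in zip(x, g)]

    @staticmethod
    def _log_q(x_to, x_from, logits):
        # Log-probability that the factorised proposal moves x_from to x_to.
        return sum(log_sigmoid(z) if a != b else log_sigmoid(-z)
                   for a, b, z in zip(x_to, x_from, logits))

    def step(self):
        """One Metropolis-Hastings transition; returns True if the proposal was accepted."""
        fwd = self._flip_logits(self.x, self.gx)
        y = [1 - xi if self.rng.random() < math.exp(log_sigmoid(z)) else xi
             for xi, z in zip(self.x, fwd)]
        fy, gy = self.log_prob(y), self.grad_log_prob(y)
        bwd = self._flip_logits(y, gy)
        log_accept = fy + self._log_q(self.x, y, bwd) - self.fx - self._log_q(y, self.x, fwd)
        if math.log(1.0 - self.rng.random()) < log_accept:
            # Accept: the proposal's values become the new cache.
            self.x, self.fx, self.gx = y, fy, gy
            return True
        return False
```

For a target f(x) = xᵀWx + bᵀx, the gradient passed in would be (W + Wᵀ)x + b. With W = 0 the bits are independent with exact marginals sigmoid(b_i), which makes a convenient sanity check for the accept/reject step.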

Each Python script in the main directory (except networks.py) corresponds to an experiment reported in the paper, as described in the sections below.

Running the experiments reported in the paper

Overview - Figure 1

Figure 1 in the paper can be reproduced by running plotting_scripts/make_paper_overview_figure.py

20D Ordinal - Figure 2

The results shown in Figure 2 of the paper were created by running the following two commands:

python sample_ordinal.py --model_name=mixture50_poly2

python sample_ordinal.py --model_name=mixture50_poly4

The results from the first command are placed in results/ordinal/dim20/mixture50_poly2_ssize50/TIMESTAMP, and analogously for the poly4 results.

Sparse Bayesian Linear - Figure 4

Figure 4 was generated by running

python sample_sparse_bayes_linear.py

Results are placed in results/sbl/20_100/TIMESTAMP.

Ising model - Table 1

Table 1 was generated by first running

python run_all_ising_script.py

and then running

python plotting_scripts/make_tables_for_paper.py --plot_type=ising_lattice_table

Convolutional EBM - Table 2

Table 2 was generated by first running

python run_all_neural_ising_script.py

and then running

python plotting_scripts/make_tables_for_paper.py --plot_type=usps_table

Step-size sensitivity - Figure 7 (appendix)

Figure 7 was generated by first running

python run_ising_stepsize_sensitivity.py

and then running

python plotting_scripts/stepsize_sensitivity_ising_plots.py

Figure 7 can then be found at results/ising_lattice_sigma0.2_stepsize_sensitivity/step_size_sensitivity.pdf.

Ising model with higher-order interactions - Figures 11 & 12 (appendix)

First run

python analyse_higher_order_sensitivity_ordinal.py

The results from this run, including Figure 12, are placed in results/ising/sigma0.2/dim16/TIMESTAMP. If we then run

python plotting_scripts/higher_order_sensitivity_plots.py --timestamp=TIMESTAMP

then Figure 11 will be placed in the same directory as Figure 12.
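Because the plotting script needs the TIMESTAMP of the earlier run, a small helper can look it up. This is a hypothetical convenience sketch, not part of the repository: latest_run_dir and plotting_command are made-up names, and since the TIMESTAMP format is not documented here, the helper picks the most recently modified sub-directory rather than sorting by name.

```python
from pathlib import Path


def latest_run_dir(results_root):
    # Hypothetical helper: return the most recently modified sub-directory of results_root,
    # assumed to be the TIMESTAMP folder written by the latest run.
    runs = [p for p in Path(results_root).iterdir() if p.is_dir()]
    if not runs:
        raise FileNotFoundError(f"no run directories in {results_root}")
    return max(runs, key=lambda p: p.stat().st_mtime)


def plotting_command(results_root="results/ising/sigma0.2/dim16"):
    # Build the plotting command shown above, filling in TIMESTAMP automatically.
    timestamp = latest_run_dir(results_root).name
    return f"python plotting_scripts/higher_order_sensitivity_plots.py --timestamp={timestamp}"
```

From the repository root, print(plotting_command()) would print the command to run; typing the TIMESTAMP by hand works just as well.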

Re-using/building on our code

We have tried to structure the code to be relatively re-usable.

All of the "pure" sampling scripts (ordinal, Bayesian regression, and Ising model with higher-order interactions) share the same structure, which can serve as a template if you wish to extend our code, e.g. by adding a new MCMC sampler or defining a new target distribution. The main() function of each script essentially does the following:

  • defines a target_dist which is a callable that computes the (unnormalised) log probability of the target distribution
  • defines the initial batch of chains, chain_init
  • defines a methods list-of-dicts, where each dict specifies a particular MCMC operator (including any hyperparameters like step-size)
  • defines a callable metric_fn, which will be repeatedly called during sampling to compute convergence metrics. metric_fn takes in x_all, which is the history of the MCMC chains up until that point, computes arbitrary metrics of interest using that history, and saves these metrics to metrics_dict.
  • defines a callable plot_and_save_fn, which takes all of the data accumulated during sampling (such as the data stored in metrics_dict), and creates plots from them.
  • passes all of the above data and functions to run_sampling_procedure, a generic function that runs a set of MCMC methods for many iterations and collates the results into figures.
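The template above can be sketched end-to-end in plain Python. Everything here is illustrative: the toy 1D Ising target, the single-site-flip Metropolis-Hastings operator mh_flip_step, the method-dict keys name and n_flips, and run_sampling_procedure_sketch (a stand-in for the repository's run_sampling_procedure, whose real signature may differ) are all assumptions. Printing replaces plotting to keep the example self-contained.

```python
import math
import random


def mh_flip_step(x, target_dist, n_flips, rng):
    # Hypothetical operator: flip n_flips distinct, uniformly chosen bits (a symmetric
    # proposal), then apply the Metropolis-Hastings accept/reject test.
    y = list(x)
    for i in rng.sample(range(len(x)), n_flips):
        y[i] = 1 - y[i]
    log_accept = target_dist(y) - target_dist(x)
    return y if math.log(1.0 - rng.random()) < log_accept else x


def run_sampling_procedure_sketch(target_dist, chain_init, methods, metric_fn,
                                  plot_and_save_fn, n_iters, rng, metric_every=100):
    # Stand-in for run_sampling_procedure: run every method, record metrics, then plot.
    all_metrics = {}
    for method in methods:
        chains = [list(x) for x in chain_init]  # each method starts from the same chains
        x_all, metrics_dict = [chains], {}
        for it in range(1, n_iters + 1):
            chains = [mh_flip_step(x, target_dist, method["n_flips"], rng) for x in chains]
            x_all.append(chains)
            if it % metric_every == 0:
                metric_fn(x_all, metrics_dict)
        all_metrics[method["name"]] = metrics_dict
    plot_and_save_fn(all_metrics)
    return all_metrics


def main(n_iters=2000, n_chains=8, dim=6, seed=0):
    rng = random.Random(seed)

    # target_dist: unnormalised log-probability (toy 1D Ising chain, bits mapped to +/-1).
    J, h = 0.5, 0.1

    def target_dist(x):
        s = [2 * xi - 1 for xi in x]
        return J * sum(a * b for a, b in zip(s, s[1:])) + h * sum(s)

    # chain_init: the initial batch of chains (uniformly random bit-vectors).
    chain_init = [[rng.randint(0, 1) for _ in range(dim)] for _ in range(n_chains)]

    # methods: one dict per MCMC operator, including its hyperparameters.
    methods = [{"name": "mh_flip1", "n_flips": 1}, {"name": "mh_flip2", "n_flips": 2}]

    # metric_fn: compute a metric from the chain history x_all and store it in metrics_dict.
    def metric_fn(x_all, metrics_dict):
        kept = x_all[len(x_all) // 2:]  # discard the first half of the history as burn-in
        mags = [sum(2 * xi - 1 for xi in x) / len(x) for chains in kept for x in chains]
        metrics_dict.setdefault("mean_magnetisation", []).append(sum(mags) / len(mags))

    # plot_and_save_fn: summarise the accumulated metrics (printing stands in for plotting).
    def plot_and_save_fn(all_metrics):
        for name, metrics in all_metrics.items():
            print(f"{name}: mean magnetisation = {metrics['mean_magnetisation'][-1]:.3f}")

    return run_sampling_procedure_sketch(target_dist, chain_init, methods, metric_fn,
                                         plot_and_save_fn, n_iters, rng)
```

Calling main() runs both methods from the same chain_init and prints one summary line per method; a real plot_and_save_fn would instead save figures. Adding a new sampler in this sketch amounts to adding a dict to methods and dispatching on it inside the driver.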
