Conformal prediction for controlling monotonic risk functions. Simple accompanying PyTorch code for conformal risk control in computer vision and natural language processing.

aangelopoulos/conformal-risk

Conformal Risk Control

This is the official repository of Conformal Risk Control by Anastasios N. Angelopoulos, Stephen Bates, Adam Fisch, Lihua Lei, and Tal Schuster.

Technical background

In the risk control problem, we are given a loss function $L_i(\lambda) = \ell(X_i,Y_i,\lambda)$ indexed by a parameter $\lambda$. For example, in multi-label classification, you can think of the loss function as the false negative proportion $L_i(\lambda) = 1 - \frac{|Y_{i} \cap C_{\lambda}(X_{i})|}{|Y_i|}$, where $C_{\lambda}(X_{i})$ is the set-valued output of a machine learning model. As $\lambda$ grows, so does the set $C_{\lambda}(X_{i})$, which shrinks the false negative proportion, so the loss is monotone in $\lambda$. We seek to choose $\hat{\lambda}$ based on the first $n$ data points so that the expected loss on a new test point is at most a user-specified risk level $\alpha$, $$\mathbb{E}\big[L_{n+1}(\hat{\lambda})\big] \leq \alpha.$$
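As a concrete illustration of the false negative proportion above, here is a minimal sketch for a single example. The function name `false_negative_proportion` and the set construction $C_{\lambda}(X) = \{k : \text{score}_k \geq 1 - \lambda\}$ are assumptions made for illustration, not taken from this repository. The sketch also assumes the scores lie in $[0, 1]$ and that the example has at least one true label.

```python
import numpy as np

def false_negative_proportion(scores, labels, lam):
    """Loss L_i(lam) = 1 - |Y_i ∩ C_lam(X_i)| / |Y_i| for one example.

    scores: float array of per-class scores in [0, 1] (assumed).
    labels: boolean array, True for each class in Y_i (at least one True).
    lam:    parameter in [0, 1]; larger lam gives a larger prediction set.
    """
    # Hypothetical set construction: include class k when score_k >= 1 - lam.
    pred_set = scores >= 1.0 - lam
    covered = np.logical_and(pred_set, labels).sum()
    return 1.0 - covered / labels.sum()

scores = np.array([0.9, 0.6, 0.2])
labels = np.array([True, False, True])
for lam in (0.0, 0.5, 1.0):
    print(lam, false_negative_proportion(scores, labels, lam))
```

With this construction the prediction set can only grow as `lam` increases, so the loss is non-increasing in `lam`, which is the monotonicity the method relies on.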

The conformal risk control algorithm is in core/get_lhat.py. It is five lines long, including the function header.
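To give a sense of how short the procedure is, the sketch below implements the selection rule from the paper: pick the smallest $\lambda$ on a grid such that $\frac{n}{n+1}\hat{R}_n(\lambda) + \frac{B}{n+1} \leq \alpha$, where $\hat{R}_n$ is the average calibration loss and $B$ is an upper bound on the loss. This is not a copy of core/get_lhat.py. The name `choose_lambda`, the argument layout, and the convention that the loss is non-increasing in $\lambda$ (as in the description above) are assumptions made for illustration.

```python
import numpy as np

def choose_lambda(loss_table, lambdas, alpha, B=1.0):
    """Pick lambda-hat from calibration losses (illustrative sketch).

    loss_table: array of shape (n, m), loss_table[i, j] = L_i(lambdas[j]).
    lambdas:    array of m grid values, sorted in ascending order.
    alpha:      target risk level.
    B:          upper bound on the loss (1.0 for a proportion).
    Assumes each row of loss_table is non-increasing in lambda.
    """
    n = loss_table.shape[0]
    rhat = loss_table.mean(axis=0)                    # empirical risk per lambda
    adjusted = (n / (n + 1)) * rhat + B / (n + 1)     # finite-sample correction
    valid = np.nonzero(adjusted <= alpha)[0]
    if valid.size == 0:
        return lambdas[-1]    # no grid value qualifies: fall back to the largest
    return lambdas[valid[0]]  # smallest lambda meeting the bound

# Toy calibration set: 4 examples, 3 grid values.
lambdas = np.array([0.0, 0.5, 1.0])
loss_table = np.array([
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])
print(choose_lambda(loss_table, lambdas, alpha=0.45))
```

The $B/(n+1)$ term accounts for the unseen test point taking the worst-case loss, which is why the rule is slightly more conservative than thresholding the empirical risk at $\alpha$.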

Examples

Each of the {polyps, coco, hierarchical-imagenet, qa} folders contains a worked example of conformal risk control with a different risk function. polyps does gut polyp segmentation with false negative rate control. coco does multi-label classification with false negative rate control. hierarchical-imagenet does hierarchical classification and chooses the resolution of its prediction by bounding the graph distance to an ancestor of the true label. Finally, qa controls the F1-score in open-world question answering.

Setup

For the computer vision experiments, run

  conda env create -f environment.yml
  conda activate conformal-risk

This will install all dependencies for the vision experiments.

For the question-answering task, follow the instructions in qa/README.md.

Reproducing the experiments

After setting up the environment, enter the example folder and run the appropriate risk_histogram.py file. To produce the grids of images in the paper, run the Python file in that folder whose name contains the word grid.

Citation

@article{angelopoulos2022conformal,
  title={Conformal Risk Control},
  author={Angelopoulos, Anastasios N and Bates, Stephen and Fisch, Adam and Lei, Lihua and Schuster, Tal},
  journal={arXiv preprint arXiv:2208.02814},
  year={2022}
}
