
Diffusion Classifier leverages pretrained diffusion models to perform zero-shot classification without additional training


Awj2021/diffusion-classifier

 
 


Your Diffusion Model is Secretly a Zero-Shot Classifier

arXiv Website

This is the official implementation of the ICCV 2023 paper Your Diffusion Model is Secretly a Zero-Shot Classifier by Alexander Li, Mihir Prabhudesai, Shivam Duggal, Ellis Brown, and Deepak Pathak.

Abstract

The recent wave of large-scale text-to-image diffusion models has dramatically increased our text-based image generation abilities. These models can generate realistic images for a staggering variety of prompts and exhibit impressive compositional generalization abilities. Almost all use cases thus far have solely focused on sampling; however, diffusion models can also provide conditional density estimates, which are useful for tasks beyond image generation. In this paper, we show that the density estimates from large-scale text-to-image diffusion models like Stable Diffusion can be leveraged to perform zero-shot classification without any additional training. Our generative approach to classification, which we call Diffusion Classifier, attains strong results on a variety of benchmarks and outperforms alternative methods of extracting knowledge from diffusion models. Although a gap remains between generative and discriminative approaches on zero-shot recognition tasks, our diffusion-based approach has significantly stronger multimodal compositional reasoning ability than competing discriminative approaches. Finally, we use Diffusion Classifier to extract standard classifiers from class-conditional diffusion models trained on ImageNet. Our models achieve strong classification performance using only weak augmentations and exhibit qualitatively better "effective robustness" to distribution shift. Overall, our results are a step toward using generative over discriminative models for downstream tasks.

Installation

Create a conda environment with the following command:

conda env create -f environment.yml

If this takes too long, running conda config --set solver libmamba switches conda to the libmamba solver, which can speed up installation.

Zero-shot Classification with Stable Diffusion

python eval_prob_adaptive.py --dataset cifar10 --split test --n_trials 1 \
  --to_keep 5 1 --n_samples 50 500 --loss l1 \
  --prompt_path prompts/cifar10_prompts.csv

This command reads candidate prompts from a CSV file and uses Stable Diffusion to evaluate the epsilon-prediction loss for each prompt. It should run on a range of GPUs, from a 2080 Ti or 3080 up to a 3090 or A6000. Losses are saved separately for each test image in the log directory; for the command above, this is data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1. Accuracy can be computed by running:

python scripts/print_acc.py data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1
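In the paper, Diffusion Classifier predicts the class whose prompt gives the lowest expected epsilon-prediction error. The flags --n_samples 50 500 --to_keep 5 1 can be read as a two-stage elimination: score every prompt at 50 timesteps, keep the best 5, then score the survivors at up to 500 timesteps and keep 1. The sketch below illustrates that idea in pure Python under stated assumptions: error_fn and toy_error are hypothetical stand-ins for a Stable Diffusion forward pass, and the timestep schedule is an illustration, not a copy of eval_prob_adaptive.py.

```python
from typing import Callable, Dict, List, Sequence, Set


def adaptive_classify(
    n_classes: int,
    error_fn: Callable[[int, int], float],
    n_samples: Sequence[int] = (50, 500),
    to_keep: Sequence[int] = (5, 1),
    num_timesteps: int = 1000,
) -> int:
    """Return the index of the prompt with the lowest mean error (sketch).

    error_fn(prompt_idx, t) is a hypothetical stand-in for the
    epsilon-prediction error of one prompt at timestep t.
    """
    # Assumed schedule: one grid of evenly spaced timesteps sized for the largest stage.
    max_n = max(n_samples)
    stride = num_timesteps // max_n
    grid = list(range(stride // 2, num_timesteps, stride))[:max_n]
    remaining = list(range(n_classes))
    errors: Dict[int, List[float]] = {c: [] for c in remaining}
    evaluated: Set[int] = set()
    for n, keep in zip(n_samples, to_keep):
        step = len(grid) // n
        # Timesteps for this stage; any already scored in an earlier stage are reused.
        new_ts = [t for t in grid[step // 2::step][:n] if t not in evaluated]
        evaluated.update(new_ts)
        for c in remaining:
            errors[c].extend(error_fn(c, t) for t in new_ts)
        # Keep the `keep` prompts with the lowest mean error so far.
        remaining.sort(key=lambda c: sum(errors[c]) / len(errors[c]))
        remaining = remaining[:keep]
    return remaining[0]


# Hypothetical toy error function in which prompt 3 is the correct class.
calls = []


def toy_error(c: int, t: int) -> float:
    calls.append((c, t))
    return abs(c - 3) + 0.01 * ((7 * t + c) % 5)


pred = adaptive_classify(10, toy_error)
print(pred, len(calls))  # 3 2750
```

With 10 prompts, the toy run makes 10 × 50 + 5 × 450 = 2,750 error evaluations, versus 5,000 for scoring every prompt at all 500 timesteps; the trade-off is that a correct class eliminated in the first stage cannot be recovered.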

Commands to run Diffusion Classifier on each dataset are here. If evaluation on your use case is taking too long, there are a few options:

  1. Parallelize evaluation across multiple workers. Try using the --n_workers and --worker_idx flags.
  2. Play around with the evaluation strategy (e.g. --n_samples and --to_keep).
  3. Evaluate on a smaller subset of the dataset. Saving a .npy array of test-set indices and passing its path with the --subset_path flag can help here.
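A subset file for --subset_path can be produced with NumPy's np.save, as in the sketch below. The every-10th-image choice and the file name are hypothetical. The strided split across workers is only an assumption about how --n_workers and --worker_idx divide the work; check eval_prob_adaptive.py for the actual partitioning.

```python
import numpy as np

# Hypothetical subset: every 10th image of a 10,000-image test split.
subset_idxs = np.arange(0, 10_000, 10)
np.save("cifar10_subset.npy", subset_idxs)  # pass this path to --subset_path


def shard(idxs: np.ndarray, n_workers: int, worker_idx: int) -> np.ndarray:
    """Assumed strided split: worker k takes every n_workers-th index from k."""
    return idxs[worker_idx::n_workers]


idxs = np.load("cifar10_subset.npy")
shards = [shard(idxs, 4, w) for w in range(4)]
print(len(idxs), [len(s) for s in shards])  # 1000 [250, 250, 250, 250]
```

A strided split keeps workers roughly balanced even if the subset is sorted by class, because each worker sees indices spread across the whole array.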

Evaluating on your own dataset

  1. Create a CSV file with the prompts you want to evaluate, making sure each prompt is paired with its correct class label. See scripts/write_cifar10_prompts.py for an example. Note that you can use multiple prompts per class.
  2. Run the command above, changing the --dataset and --prompt_path flags to match your use case.
  3. Play around with the evaluation strategy on a small subset of the dataset to reduce evaluation time.
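The sketch below writes a prompt CSV for a custom dataset with Python's csv module, using two prompts per class. The column names prompt, classname, and classidx are an assumption; confirm them against scripts/write_cifar10_prompts.py and prompts/cifar10_prompts.csv. The class names, templates, and output file name are hypothetical. Whatever the format, classidx must match the label indices your dataset returns.

```python
import csv

# Hypothetical 3-class dataset; list order must match the dataset's label indices.
classnames = ["cat", "dog", "bird"]
templates = ["a photo of a {}.", "a blurry photo of a {}."]  # several prompts per class

rows = [
    {"prompt": t.format(name), "classname": name, "classidx": idx}
    for idx, name in enumerate(classnames)
    for t in templates
]

# Assumed column names; check scripts/write_cifar10_prompts.py.
with open("my_dataset_prompts.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["prompt", "classname", "classidx"])
    writer.writeheader()
    writer.writerows(rows)

# Read the file back to sanity-check it.
with open("my_dataset_prompts.csv", newline="") as f:
    loaded = list(csv.DictReader(f))
print(len(loaded), loaded[0])
```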

Standard ImageNet Classification with Class-conditional Diffusion Models

Additional installations

Within the diffusion-classifier folder, clone the DiT repository:

git clone git@github.com:facebookresearch/DiT.git

Running Diffusion Classifier

First, save a consistent set of noise (epsilon) that will be used for all image-class pairs:

python scripts/save_noise.py --img_size 256
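Reusing one saved noise tensor means every class (or caption) is scored against exactly the same noise draws, so differences in error come from the conditioning rather than from sampling luck. A minimal NumPy sketch of the idea follows. It assumes a 4-channel latent at 1/8 of the image resolution, as produced by the Stable Diffusion VAE that DiT also uses, and a hypothetical sample count. Unlike scripts/save_noise.py, which writes a .pt file, it keeps the array in memory.

```python
import numpy as np


def make_fixed_noise(n_samples: int, img_size: int, seed: int = 0,
                     latent_channels: int = 4, downsample: int = 8) -> np.ndarray:
    """Draw a reproducible batch of Gaussian latent noise (assumed shapes)."""
    rng = np.random.default_rng(seed)  # fixed seed -> identical noise every call
    side = img_size // downsample      # e.g. 256 -> 32
    return rng.standard_normal((n_samples, latent_channels, side, side)).astype(np.float32)


noise_a = make_fixed_noise(8, 256)
noise_b = make_fixed_noise(8, 256)
print(noise_a.shape, np.array_equal(noise_a, noise_b))  # (8, 4, 32, 32) True
```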

Then, compute and save the epsilon-prediction error for each class:

python eval_prob_dit.py --dataset imagenet --split test \
  --noise_path noise_256.pt --randomize_noise \
  --batch_size 32 --cls CLS --t_interval 4 --extra dit256 --save_vb

For ImageNet, for example, this must be run once for each CLS from 0 to 999. This is currently very expensive, so we recommend using the --subset_path flag to evaluate on a smaller subset of the dataset. We also plan to release an adaptive version that greatly reduces the computation time per test image.
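Because eval_prob_dit.py handles one class per invocation, it helps to generate the 1,000 commands programmatically. The sketch below only builds argument lists using the flags shown above; the subset file name and the 4-GPU split are hypothetical, and launching each GPU's list (for example, sequentially with CUDA_VISIBLE_DEVICES set) is left to your job scheduler.

```python
import shlex
from typing import List, Optional


def dit_commands(n_classes: int = 1000, subset_path: Optional[str] = None) -> List[List[str]]:
    """Build one eval_prob_dit.py command per class index."""
    base = [
        "python", "eval_prob_dit.py", "--dataset", "imagenet", "--split", "test",
        "--noise_path", "noise_256.pt", "--randomize_noise",
        "--batch_size", "32", "--t_interval", "4", "--extra", "dit256", "--save_vb",
    ]
    cmds = []
    for cls in range(n_classes):
        cmd = base + ["--cls", str(cls)]
        if subset_path is not None:
            cmd += ["--subset_path", subset_path]  # hypothetical subset file
        cmds.append(cmd)
    return cmds


cmds = dit_commands(subset_path="imagenet_subset.npy")
# Hypothetical split of the classes across 4 GPUs, strided so each gets 250.
n_gpus = 4
per_gpu = {g: cmds[g::n_gpus] for g in range(n_gpus)}
print(len(cmds))
print(shlex.join(cmds[0]))
```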

Finally, compute the accuracy using the saved errors:

python scripts/print_dit_acc.py data/imagenet_dit256 --dataset imagenet

We show the commands to run DiT on all ImageNet variants here.

Compositional Reasoning on Winoground with Stable Diffusion

To run Diffusion Classifier on Winoground: First, save a consistent set of noise (epsilon) that will be used for all image-caption pairs:

python scripts/save_noise.py --img_size 512

Then, evaluate on Winoground:

python run_winoground.py --model sd --version 2-0 --t_interval 1 --batch_size 32 --noise_path noise_512.pt --randomize_noise --interpolation bicubic

To run CLIP or OpenCLIP baselines:

python run_winoground.py --model clip --version ViT-L/14
python run_winoground.py --model openclip --version ViT-H-14
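Winoground results are usually reported as text, image, and group scores. Each example has two images and two captions, and a model is credited only when it ranks the matching pairs above the swapped ones. The sketch below implements those standard definitions. Using the negative epsilon-prediction error as the Diffusion Classifier's match score follows the paper's decision rule, but the error values are made up and the internals of run_winoground.py are not reproduced here.

```python
from typing import Dict, Tuple

Scores = Dict[Tuple[int, int], float]  # (caption_idx, image_idx) -> match score


def winoground_scores(s: Scores) -> Dict[str, bool]:
    """Standard Winoground metrics for one example (higher score = better match)."""
    # Text score: for each image, its own caption must beat the other caption.
    text = s[(0, 0)] > s[(1, 0)] and s[(1, 1)] > s[(0, 1)]
    # Image score: for each caption, its own image must beat the other image.
    image = s[(0, 0)] > s[(0, 1)] and s[(1, 1)] > s[(1, 0)]
    return {"text": text, "image": image, "group": text and image}


# Made-up epsilon-prediction errors; lower error means a better match.
errors = {(0, 0): 0.10, (0, 1): 0.09, (1, 0): 0.12, (1, 1): 0.08}
result = winoground_scores({k: -v for k, v in errors.items()})
print(result)  # {'text': True, 'image': False, 'group': False}
```

Dataset-level scores are the fraction of examples for which each flag is True; the group score is the strictest because it requires both the text and image conditions to hold.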

Citation

If you find this work useful in your research, please cite:

@misc{li2023diffusion,
      title={Your Diffusion Model is Secretly a Zero-Shot Classifier}, 
      author={Alexander C. Li and Mihir Prabhudesai and Shivam Duggal and Ellis Brown and Deepak Pathak},
      year={2023},
      eprint={2303.16203},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
