HKUNLP/diffusion-of-thoughts

Diffusion of Thoughts

This repository contains code for training and evaluating the models in the paper Diffusion of Thoughts: Chain-of-Thought Reasoning in Diffusion Language Models.

  • Diffusion models have gained attention in text processing, offering many potential advantages over traditional autoregressive models. We explore the integration of diffusion models and Chain-of-Thought (CoT), a well-established technique to improve the reasoning ability in autoregressive language models.
  • We propose Diffusion-of-Thought (DoT), allowing reasoning steps to diffuse over time through the diffusion process. In contrast to traditional autoregressive language models that make decisions in a left-to-right, token-by-token manner, DoT offers more flexibility in the trade-off between computation and reasoning performance.
  • Additionally, DoT showcases promising self-correction abilities and benefits from existing reasoning-enhancing techniques like self-consistency decoding. Our findings contribute to the understanding and development of reasoning capabilities in diffusion language models.
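Self-consistency decoding, mentioned above, can be illustrated with a minimal sketch: sample several reasoning paths for the same question, extract a final answer from each, and keep the most frequent one. The function name and the assumption that answers arrive as already-parsed strings are hypothetical and are not taken from this repository's code.

```python
from collections import Counter


def self_consistency_vote(answers):
    """Return the most frequent final answer among sampled reasoning paths.

    `answers` is a list of answer strings, one per sampled generation.
    Entries that are None or empty (no answer could be parsed) are ignored.
    """
    valid = [a.strip() for a in answers if a is not None and a.strip()]
    if not valid:
        return None
    # Majority vote over the normalized answer strings.
    return Counter(valid).most_common(1)[0][0]


# Hypothetical answers parsed from five sampled generations for one question.
samples = ["18", "18", "20", None, "18 "]
print(self_consistency_vote(samples))
```

Voting on the final answer rather than on the full reasoning text means that paths with different intermediate steps can still agree.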

DoT pipeline demonstration.

Our implementation of DoT is mainly based on DiffuSeq (DiffuSeq: Sequence to Sequence Text Generation With Diffusion Models) and Plaid (Likelihood-Based Diffusion Language Models). Our task and dataset configurations primarily follow those of Implicit CoT. Thanks for this excellent work!

Setup

All required packages can be found in requirements.txt. You can install them in a new environment with

conda create -n dot python=3.10
conda activate dot

git clone git@github.com:HKUNLP/diffusion-of-thoughts.git
cd diffusion-of-thoughts

# Replace the following line depending on your CUDA version.
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html

Install NVIDIA Apex with fused kernels:

git clone https://github.com/NVIDIA/apex
cd apex
git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./

Finetuning from Plaid 1B

First download the weights from the Plaid 1B Weights Download Page. Then download the 4by4/5by5/GSM8k-Aug data and put it in the ./data folder under the names 4by4, 5by5, and gsm8k.

Extract the weights:

cat plaid1b_weights.tar.gz.* | tar xvzf -
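Before training, it can help to confirm that the data landed where the commands expect it. The following is a small sketch, not part of the repository; `missing_datasets` is a hypothetical helper, and it only checks that an entry with each expected name exists under the data folder, since the README does not say whether these are files or directories.

```python
from pathlib import Path

# Names taken from the README; their internal layout is not assumed here.
EXPECTED_DATASETS = ("4by4", "5by5", "gsm8k")


def missing_datasets(data_root="./data", names=EXPECTED_DATASETS):
    """Return the expected dataset names that are absent under data_root."""
    root = Path(data_root)
    return [name for name in names if not (root / name).exists()]


# An empty list means every expected entry was found.
print(missing_datasets())
```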

Then run the following code:

# DoT
python train.py --digit --fix_src --dataset gsm8k --steps 120000 --weights_path plaid1b_weights 

# DoT-MP
python train.py --digit --fix_src --cot --dataset gsm8k --steps 31000 --weights_path plaid1b_weights 

Please refer to run_train.sh for more training commands.
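The evaluation commands in this README load checkpoints from directories such as outputs/gsm8k-bs128-fix_src-digit-steps120000, which suggests that the run name encodes the training flags. The sketch below reproduces that pattern. It is inferred only from the two paths shown in this README, `run_name` is a hypothetical helper, and the default batch size of 128 is an assumption read off the `bs128` component; train.py may compose names differently.

```python
def run_name(dataset, steps, batch_size=128, fix_src=False, cot=False, digit=False):
    """Compose a run directory name in the pattern seen in the evaluation paths.

    Assumed pattern: <dataset>-bs<batch_size>[-fix_src][-cot][-digit]-steps<steps>
    """
    parts = [dataset, f"bs{batch_size}"]
    if fix_src:
        parts.append("fix_src")
    if cot:
        parts.append("cot")
    if digit:
        parts.append("digit")
    parts.append(f"steps{steps}")
    return "-".join(parts)


# The two training commands above, mapped to the paths used for evaluation.
print("outputs/" + run_name("gsm8k", 120000, fix_src=True, digit=True))
print("outputs/" + run_name("gsm8k", 31000, fix_src=True, cot=True, digit=True))
```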

Evaluation

Here are some evaluation commands; please refer to run_eval.sh for more examples.

# DoT (T=64 by default)
python3 evaluation_batch.py --weights_path outputs/gsm8k-bs128-fix_src-digit-steps120000 --fix_src --digit --dataset gsm8k --score_temp 0.5

# DoT with DPM-Solver
python3 evaluation_batch.py --weights_path outputs/gsm8k-bs128-fix_src-digit-steps120000 --fix_src --digit --dataset gsm8k --dpm_solver

# DoT, T=8
python3 evaluation_batch.py --weights_path outputs/gsm8k-bs128-fix_src-digit-steps120000 --fix_src --digit --dataset gsm8k --score_temp 0.5 --sampling_timesteps 8

# DoT-MP
python3 evaluation_batch.py --weights_path outputs/gsm8k-bs128-fix_src-cot-digit-steps31000 --fix_src --digit --cot --dataset gsm8k --score_temp 0.5
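To compare accuracy across different numbers of sampling timesteps, the commands above can be generated programmatically. The sketch below only builds and prints command lines using flags that appear in this README. `eval_command` is a hypothetical helper, and passing `--sampling_timesteps 64` explicitly is assumed to match the default stated above.

```python
import shlex


def eval_command(weights_path, dataset="gsm8k", score_temp=0.5,
                 sampling_timesteps=None, cot=False):
    """Build an evaluation_batch.py command line as a list of arguments."""
    cmd = ["python3", "evaluation_batch.py",
           "--weights_path", weights_path, "--fix_src", "--digit"]
    if cot:
        cmd.append("--cot")
    cmd += ["--dataset", dataset, "--score_temp", str(score_temp)]
    # Omit the flag entirely to fall back to the script's default.
    if sampling_timesteps is not None:
        cmd += ["--sampling_timesteps", str(sampling_timesteps)]
    return cmd


weights = "outputs/gsm8k-bs128-fix_src-digit-steps120000"
for t in (8, 64):
    # Print each command; pass the list to subprocess.run to execute it.
    print(shlex.join(eval_command(weights, sampling_timesteps=t)))
```

Keeping the command as a list of arguments avoids quoting problems if a path ever contains spaces.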

Pretrained checkpoints are undergoing a policy review, and we will release them as soon as possible.

More Cases

example1

example2

example3

Citation

@article{ye2024diffusion,
  title={Diffusion of Thoughts: Chain-of-Thought Reasoning in Diffusion Language Models},
  author={Ye, Jiacheng and Gong, Shansan and Chen, Liheng and Zheng, Lin and Gao, Jiahui and Shi, Han and Wu, Chuan and Li, Zhenguo and Bi, Wei and Kong, Lingpeng},
  journal={arXiv preprint arXiv:2402.07754},
  year={2024}
}