QwenLM/Qwen-Audio

Qwen-Audio (Qwen Large Audio Language Model) is the multimodal version of the large model series Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-Audio accepts diverse audio (human speech, natural sounds, music, and songs) and text as input and outputs text. The contributions of Qwen-Audio include:

  • Fundamental audio models: Qwen-Audio is a fundamental multi-task audio-language model that supports various tasks, languages, and audio types, serving as a universal audio understanding model. Building upon Qwen-Audio, we develop Qwen-Audio-Chat through instruction fine-tuning, enabling multi-turn dialogues and supporting diverse audio-oriented scenarios.
  • Multi-task learning framework for all types of audio: To scale up audio-language pre-training, we address the challenge of variation in the textual labels associated with different datasets by proposing a multi-task training framework that enables knowledge sharing and avoids one-to-many interference. Our model incorporates more than 30 tasks, and extensive experiments show that it achieves strong performance.
  • Strong performance: Experimental results show that Qwen-Audio achieves impressive performance across diverse benchmark tasks without requiring any task-specific fine-tuning, surpassing its counterparts. Specifically, Qwen-Audio achieves state-of-the-art results on the test sets of Aishell1, Cochlscene, ClothoAQA, and VocalSound.
  • Flexible multi-turn chat from audio and text input: Qwen-Audio supports multiple-audio analysis, sound understanding and reasoning, music appreciation, and tool usage.


We have released two models in the Qwen-Audio series:

  • Qwen-Audio: The pre-trained multi-task audio understanding model uses Qwen-7B as the initialization of the LLM, and Whisper-large-v2 as the initialization of the audio encoder.
  • Qwen-Audio-Chat: A multimodal LLM-based AI assistant, which is trained with alignment techniques. Qwen-Audio-Chat supports more flexible interaction, such as multiple audio inputs, multi-round question answering, and creative capabilities.

News and Updates

  • 2023.11.30 🔥 We have released the checkpoints of both Qwen-Audio and Qwen-Audio-Chat on ModelScope and Hugging Face.
  • 2023.11.15 🎉 We released a paper describing the Qwen-Audio and Qwen-Audio-Chat models, including training details and model performance.

Evaluation

We evaluated Qwen-Audio on 12 standard benchmarks. The detailed results for each task are as follows:

Automatic Speech Recognition

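Results are reported as WER.

| Dataset | Model | dev-clean | dev-other | test-clean | test-other |
|---|---|---|---|---|---|
| Librispeech | SpeechT5 | 2.1 | 5.5 | 2.4 | 5.8 |
| Librispeech | SpeechNet | - | - | 30.7 | - |
| Librispeech | SLM-FT | - | - | 2.6 | 5.0 |
| Librispeech | SALMONN | - | - | 2.1 | 4.9 |
| Librispeech | Qwen-Audio | 1.8 | 4.0 | 2.0 | 4.2 |

| Dataset | Model | dev | test |
|---|---|---|---|
| Aishell1 | MMSpeech-base | 2.0 | 2.1 |
| Aishell1 | MMSpeech-large | 1.6 | 1.9 |
| Aishell1 | Paraformer-large | - | 2.0 |
| Aishell1 | Qwen-Audio | 1.2 (SOTA) | 1.3 (SOTA) |

| Dataset | Model | Mic | iOS | Android |
|---|---|---|---|---|
| Aishell2 | MMSpeech-base | 4.5 | 3.9 | 4.0 |
| Aishell2 | Paraformer-large | - | 2.9 | - |
| Aishell2 | Qwen-Audio | 3.3 | 3.1 | 3.3 |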
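
For readers unfamiliar with the metric, the sketch below shows one common way to compute word error rate: the word-level edit distance between reference and hypothesis, divided by the number of reference words. It only illustrates the metric. It is not the repository's evaluation script, and the function name `word_error_rate` is hypothetical. Real evaluations usually also normalize the text (casing, punctuation), which is omitted here.

```python
from typing import List


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Return word-level edit distance divided by the reference length."""
    ref: List[str] = reference.split()
    hyp: List[str] = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference words processed
    # so far and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1
            insertion = curr[j - 1] + 1
            curr.append(min(substitution, deletion, insertion))
        prev = curr
    return prev[-1] / len(ref)


reference = ("mister quilter is the apostle of the middle classes "
             "and we are glad to welcome his gospel")
hypothesis = ("mister quilting is the apostle of the middle classes "
              "and we are glad to welcome his gospel")
# One substituted word out of 17 reference words.
print(word_error_rate(reference, hypothesis))
```

The function returns a fraction; multiply by 100 to express it as a percentage.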

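Speech-to-text Translation

Results are reported as BLEU.

| Dataset | Model | en-de | de-en | en-zh | zh-en | es-en | fr-en | it-en |
|---|---|---|---|---|---|---|---|---|
| CoVoST2 | SALMONN | 18.6 | - | 33.1 | - | - | - | - |
| CoVoST2 | SpeechLLaMA | - | 27.1 | - | 12.3 | 27.9 | 25.2 | 25.9 |
| CoVoST2 | BLSP | 14.1 | - | - | - | - | - | - |
| CoVoST2 | Qwen-Audio | 25.1 | 33.9 | 41.5 | 15.7 | 39.7 | 38.5 | 36.0 |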

Automatic Audio Caption

| Dataset | Model | CIDEr | SPICE | SPIDEr |
|---|---|---|---|---|
| Clotho | Pengi | 0.416 | 0.126 | 0.271 |
| Clotho | Qwen-Audio | 0.441 | 0.136 | 0.288 |

Speech Recognition with Word-level Timestamp

| Dataset | Model | AAC (ms) |
|---|---|---|
| Industrial Data | Force-aligner | 60.3 |
| Industrial Data | Paraformer-large-TP | 65.3 |
| Industrial Data | Qwen-Audio | 51.5 (SOTA) |

Automatic Scene Classification

| Dataset | Model | ACC |
|---|---|---|
| Cochlscene | Cochlscene | 0.669 |
| Cochlscene | Qwen-Audio | 0.795 (SOTA) |
| TUT2017 | Pengi | 0.353 |
| TUT2017 | Qwen-Audio | 0.649 |

Speech Emotion Recognition

| Dataset | Model | ACC |
|---|---|---|
| Meld | WavLM-large | 0.542 |
| Meld | Qwen-Audio | 0.557 |

Audio Question & Answer

| Dataset | Model | ACC | ACC (binary) |
|---|---|---|---|
| ClothoAQA | ClothoAQA | 0.542 | 0.627 |
| ClothoAQA | Pengi | - | 0.645 |
| ClothoAQA | Qwen-Audio | 0.579 | 0.749 |

Vocal Sound Classification

| Dataset | Model | ACC |
|---|---|---|
| VocalSound | CLAP | 0.4945 |
| VocalSound | Pengi | 0.6035 |
| VocalSound | Qwen-Audio | 0.9289 (SOTA) |

Music Note Analysis

| Dataset | Model | NS. Qualities (MAP) | NS. Instrument (ACC) |
|---|---|---|---|
| NSynth | Pengi | 0.3860 | 0.5007 |
| NSynth | Qwen-Audio | 0.4742 | 0.7882 |

We have provided all evaluation scripts to reproduce our results. Please refer to eval_audio/EVALUATION.md for details.

Evaluation of Chat

To evaluate the chat abilities of Qwen-Audio-Chat, we provide a tutorial (TUTORIAL) and a demo.

Requirements

  • Python 3.8 or later
  • PyTorch 1.12 or later (2.0 or later recommended)
  • CUDA 11.4 or later recommended (for GPU users)
  • FFmpeg
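
The requirements above can be checked programmatically before installing anything else. The following is a minimal sketch, not part of the repository: the helper name `check_environment` and the keys of the returned dictionary are hypothetical. It reports the PyTorch version without comparing it to 1.12, and it treats a missing PyTorch install as a reportable state rather than an error.

```python
import shutil
import sys
from typing import Any, Dict


def check_environment() -> Dict[str, Any]:
    """Collect basic facts about the local setup relevant to the requirements."""
    report: Dict[str, Any] = {
        # The README asks for Python 3.8 or later.
        "python_ok": sys.version_info >= (3, 8),
        # FFmpeg must be discoverable on PATH.
        "ffmpeg_found": shutil.which("ffmpeg") is not None,
    }
    try:
        import torch  # optional at this point; only inspected if installed
    except ImportError:
        report["torch_version"] = None
        report["cuda_available"] = False
    else:
        report["torch_version"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```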

Quickstart

Below, we provide simple examples to show how to use Qwen-Audio and Qwen-Audio-Chat with 🤖 ModelScope and 🤗 Transformers.

Before running the code, make sure you have set up the environment and meet the requirements above, then install the dependencies.

pip install -r requirements.txt

Now you can start with ModelScope or Transformers. For more usage, please refer to the tutorial. Qwen-Audio models currently perform best with audio clips under 30 seconds.
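
One way to work with longer recordings, given the 30-second guidance above, is to split them into consecutive segments and query the model on each segment separately. The sketch below only computes the segment boundaries. The function name `chunk_spans` is hypothetical, and chunking is an assumption made here for illustration, not a procedure documented by the authors. Cutting the audio itself is left to a tool of your choice, such as FFmpeg.

```python
from typing import List, Tuple


def chunk_spans(duration_s: float, max_len_s: float = 30.0) -> List[Tuple[float, float]]:
    """Split [0, duration_s] into consecutive (start, end) spans of at most max_len_s."""
    if duration_s < 0:
        raise ValueError("duration_s must be non-negative")
    if max_len_s <= 0:
        raise ValueError("max_len_s must be positive")

    spans: List[Tuple[float, float]] = []
    start = 0.0
    while start < duration_s:
        end = min(start + max_len_s, duration_s)
        spans.append((start, end))
        start = end
    return spans


# A 75-second recording becomes two full segments and one shorter remainder.
print(chunk_spans(75.0))
```

Splitting at fixed offsets can cut a word in half. Lowering `max_len_s` or choosing boundaries at silences may give better results, but that is outside the scope of this sketch.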

🤗 Transformers

To run inference with Qwen-Audio-Chat, you only need a few lines of code, as shown below. Please make sure that you are using the latest code.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
torch.manual_seed(1234)

# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-Audio-Chat", trust_remote_code=True)

# use bf16
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
# use fp16
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
# use cpu only
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio-Chat", device_map="cpu", trust_remote_code=True).eval()
# use cuda device
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio-Chat", device_map="cuda", trust_remote_code=True).eval()

# Specify hyperparameters for generation (No need to do this if you are using transformers>4.32.0)
# model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-Audio-Chat", trust_remote_code=True)

# 1st dialogue turn
query = tokenizer.from_list_format([
    {'audio': 'assets/audio/1272-128104-0000.flac'}, # Either a local path or a URL
    {'text': 'what does the person say?'},
])
response, history = model.chat(tokenizer, query=query, history=None)
print(response)
# The person says: "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel".

# 2nd dialogue turn
response, history = model.chat(tokenizer, 'Find the start time and end time of the word "middle classes"', history=history)
print(response)
# The word "middle classes" starts at <|2.33|> seconds and ends at <|3.26|> seconds.
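
The second response above embeds timestamps as `<|seconds|>` markers. If you need them as numbers, a small regular expression is enough. This is a minimal sketch that assumes the markers always follow the format shown in the example output; the helper name `extract_timestamps` is hypothetical and not part of the model's API.

```python
import re
from typing import List

# Matches markers such as <|2.33|> and captures the number inside.
TIMESTAMP = re.compile(r"<\|(\d+(?:\.\d+)?)\|>")


def extract_timestamps(response: str) -> List[float]:
    """Return every <|seconds|> marker in the response as a float, in order."""
    return [float(value) for value in TIMESTAMP.findall(response)]


sample = 'The word "middle classes" starts at <|2.33|> seconds and ends at <|3.26|> seconds.'
print(extract_timestamps(sample))  # [2.33, 3.26]
```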

Running the pretrained Qwen-Audio base model is also simple.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
torch.manual_seed(1234)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-Audio", trust_remote_code=True)

# use bf16
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio", device_map="auto", trust_remote_code=True, bf16=True).eval()
# use fp16
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio", device_map="auto", trust_remote_code=True, fp16=True).eval()
# use cpu only
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio", device_map="cpu", trust_remote_code=True).eval()
# use cuda device
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-Audio", device_map="cuda", trust_remote_code=True).eval()

# Specify hyperparameters for generation (No need to do this if you are using transformers>4.32.0)
# model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-Audio", trust_remote_code=True)
audio_url = "assets/audio/1272-128104-0000.flac"
sp_prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
query = f"<audio>{audio_url}</audio>{sp_prompt}"
audio_info = tokenizer.process_audio(query)
inputs = tokenizer(query, return_tensors='pt', audio_info=audio_info)
inputs = inputs.to(model.device)
pred = model.generate(**inputs, audio_info=audio_info)
response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
print(response)
# <audio>assets/audio/1272-128104-0000.flac</audio><|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>mister quilting is the apostle of the middle classes and we are glad to welcome his gospel<|endoftext|>
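
Because the example decodes with `skip_special_tokens=False`, the printed string still contains the audio tag, the task prompt, and the end-of-text token. The sketch below strips these to leave only the transcript. It assumes the output always has the shape shown in the comment above; the helper name `extract_transcript` is hypothetical.

```python
import re

# The <audio>...</audio> span that echoes the input path.
AUDIO_TAG = re.compile(r"<audio>.*?</audio>", re.DOTALL)
# Any special token written as <|name|>, e.g. <|en|> or <|endoftext|>.
SPECIAL_TOKEN = re.compile(r"<\|[^|<>]*\|>")


def extract_transcript(decoded: str) -> str:
    """Remove the audio tag and all <|...|> tokens from a decoded response."""
    without_audio = AUDIO_TAG.sub("", decoded)
    without_tokens = SPECIAL_TOKEN.sub("", without_audio)
    return without_tokens.strip()


decoded = (
    "<audio>assets/audio/1272-128104-0000.flac</audio>"
    "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
    "mister quilting is the apostle of the middle classes "
    "and we are glad to welcome his gospel<|endoftext|>"
)
print(extract_transcript(decoded))
```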

If a network issue prevents you from downloading the model checkpoints and code from Hugging Face, you can fetch the checkpoint from ModelScope first and then load it from the local directory, as outlined below:

from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Downloading model checkpoint to a local dir model_dir
model_id = 'qwen/Qwen-Audio-Chat'
revision = 'master'
model_dir = snapshot_download(model_id, revision=revision)

# Loading local checkpoints
# trust_remote_code must remain True because the model code is loaded from the local directory rather than from the transformers library
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="cuda",
    trust_remote_code=True
).eval()
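
The two download paths can also be combined so that ModelScope is tried only when the Hugging Face download fails. The following is a sketch under assumptions: the helpers `first_successful`, `load_tokenizer_from_hf`, and `load_tokenizer_via_modelscope` are hypothetical names introduced here, and any exception raised by a loader is treated as a reason to try the next one. The library calls inside the loaders are the same ones used in the examples above, and the imports are deferred so that the block can be read and tested without either library installed.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def first_successful(loaders: Sequence[Callable[[], T]]) -> T:
    """Call each loader in order and return the first result that does not raise."""
    errors: List[Exception] = []
    for load in loaders:
        try:
            return load()
        except Exception as exc:  # e.g. a network error during download
            errors.append(exc)
    raise RuntimeError(f"all loaders failed: {errors!r}")


def load_tokenizer_from_hf():
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained("Qwen/Qwen-Audio-Chat", trust_remote_code=True)


def load_tokenizer_via_modelscope():
    from modelscope import snapshot_download
    from transformers import AutoTokenizer
    model_dir = snapshot_download("qwen/Qwen-Audio-Chat", revision="master")
    return AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)


# Usage (requires network access and both libraries):
# tokenizer = first_successful([load_tokenizer_from_hf, load_tokenizer_via_modelscope])
```

The same pattern applies to the model: wrap each `AutoModelForCausalLM.from_pretrained` call in its own loader function and pass both to `first_successful`.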

🤖 ModelScope

ModelScope is an open-source platform for Model-as-a-Service (MaaS) that provides flexible and cost-effective model services to AI developers. You can run the models with ModelScope in the same way, as shown below:

from modelscope import (
    snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
)
import torch
model_id = 'qwen/Qwen-Audio-Chat'
revision = 'master'

model_dir = snapshot_download(model_id, revision=revision)
torch.manual_seed(1234)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if not hasattr(tokenizer, 'model_dir'):
    tokenizer.model_dir = model_dir
# use bf16
# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, bf16=True).eval()
# use fp16
# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, fp16=True).eval()
# use CPU
# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cpu", trust_remote_code=True).eval()
# use gpu
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()

# 1st dialogue turn
query = tokenizer.from_list_format([
    {'audio': 'assets/audio/1272-128104-0000.flac'}, # Either a local path or a URL
    {'text': 'what does the person say?'},
])
response, history = model.chat(tokenizer, query=query, history=None)
print(response)
# The person says: "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel".

# 2nd dialogue turn
response, history = model.chat(tokenizer, 'Find the start time and end time of the word "middle classes"', history=history)
print(response)
# The word "middle classes" starts at <|2.33|> seconds and ends at <|3.26|> seconds.

Demo

Web UI

We provide code for users to build a web UI demo. Before you start, make sure you install the following packages:

pip install -r requirements_web_demo.txt

Then run the command below and click on the generated link:

python web_demo_audio.py

FAQ

If you run into problems, please check the FAQ and the existing issues for a solution before opening a new issue.

We Are Hiring

If you are interested in joining us as a full-time employee or an intern, please contact us at [email protected].

License Agreement

Researchers and developers are free to use the code and model weights of both Qwen-Audio and Qwen-Audio-Chat. Commercial use is also allowed. See LICENSE for details.

Citation

If you find our paper and code useful in your research, please consider giving a star ⭐ and citation 📝 :)

@article{Qwen-Audio,
  title={Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models},
  author={Chu, Yunfei and Xu, Jin and Zhou, Xiaohuan and Yang, Qian and Zhang, Shiliang and Yan, Zhijie  and Zhou, Chang and Zhou, Jingren},
  journal={arXiv preprint arXiv:2311.07919},
  year={2023}
}

Contact Us

If you would like to leave a message for either our research team or our product team, feel free to send an email to [email protected].

About

The official repo of Qwen-Audio (通义千问-Audio) chat & pretrained large audio language model proposed by Alibaba Cloud.
