fix typos (#31)
Signed-off-by: Zhang, Weiwei1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: wenhuach21 <[email protected]>
3 people committed Mar 8, 2024
1 parent d02f94d commit 53fb877
Showing 6 changed files with 149 additions and 104 deletions.
88 changes: 46 additions & 42 deletions README.md
@@ -23,7 +23,7 @@ python setup.py install
```
## Usage of Tuning

### On CPU/Gaudi2/ GPU
### On CPU/ Gaudi2/ GPU

```python
import torch
@@ -44,6 +44,48 @@ output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir)
```
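
The hunk above elides the middle of the tuning example, so for orientation the full flow looks roughly like the sketch below (the model name and values are illustrative; the constructor arguments follow the calls that appear later in this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "EleutherAI/gpt-neo-125m"  # illustrative; any supported causal LM works
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

bits, group_size, sym = 4, 128, False
autoround = AutoRound(model, tokenizer, bits, group_size, sym=sym)
model, weight_config = autoround.quantize()

output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir)
```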



## Model inference
Please run the tuning code first.



### Intel CPU
```python
# Please save the quantized model in 'itrex' format first, then refer to the ITREX tutorial for more details on inference with the INT4 model.
# (https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/llm/runtime/neural_speed)
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from transformers import AutoTokenizer

quantized_model_path = "./tmp_autoround"
sym = False  # set to the `sym` value used during tuning
group_size = 128  # set to the `group_size` value used during tuning
scheme = "sym" if sym else "asym"
woq_config = WeightOnlyQuantConfig(
    group_size=group_size, scheme=scheme, use_autoround=True
)  # only 4-bit quantization is currently supported
prompt = "There is a girl who likes adventure,"
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
model = AutoModelForCausalLM.from_pretrained(
    quantized_model_path, quantization_config=woq_config, trust_remote_code=True, device="cpu"
)
outputs = model.generate(inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```


### GPU
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```

<details>
<summary>Detailed Hyperparameters</summary>

@@ -71,7 +113,7 @@ autoround.save_quantized(output_dir)

- `seqlen (int)`: Data length of the sequence for tuning (default is 2048).

- `bs (int)`: Batch size for training (default is 8).
- `batch_size (int)`: Batch size for training (default is 8).

- `scale_dtype (str)`: The data type of the quantization scale (default is "float32"); different kernels support different choices.

@@ -91,7 +133,7 @@ autoround.save_quantized(output_dir)

- `weight_config (dict)`: Configuration for weight quantization (default is an empty dictionary), mainly for mixed-bit or mixed-precision settings; see the sketch after this list.

- `device`: The device to be used for tuning. The default is set to None, allowing for automatic detection.
- `device`: The device to be used for tuning. The default is set to 'auto', allowing for automatic detection.

</details>
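
As a rough illustration of the `weight_config` option above, the sketch below keeps one layer in floating point (the layer name is hypothetical, the keyword name is assumed from the parameter list, and the per-layer keys mirror the usage in `examples/language-modeling/main.py` later in this commit):

```python
from auto_round import AutoRound

# Keep a layer in its original floating-point precision instead of quantizing it;
# "lm_head" is an illustrative layer name, and passing the dict via the
# `weight_config` keyword is an assumption based on the parameter list above.
weight_config = {
    "lm_head": {"data_type": "fp"},
}

# `model` and `tokenizer` as loaded in the tuning example above; 4 = bits, 128 = group_size.
autoround = AutoRound(model, tokenizer, 4, 128, weight_config=weight_config)
```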

@@ -123,46 +165,8 @@ autoround.save_quantized(output_dir)
| MBZUAI/LaMini-GPT-124M | [example](./examples/language-modeling/) |
| EleutherAI/gpt-neo-125m | [example](./examples/language-modeling/) |
| databricks/dolly-v2-3b | [example](./examples/language-modeling/) |
| stabilityai/stablelm-base-alpha-3b | [example](./examples/language-modeling/) |


## Model inference
Please run the tuning code first


| stabilityai/stablelm-base-alpha-3b | [example](./examples/language-modeling/)

### Intel CPU
```python
# save_quantized to itrex format first
# Please read ITREX(https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/llm/runtime/neural_speed) to understand the details
# currently please install neural-speed (https://github.com/intel/neural-speed) from source
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from transformers import AutoTokenizer

quantized_model_path = "./tmp_autoround"
scheme = "sym" if sym else "asym"
woq_config = WeightOnlyQuantConfig(
group_size=group_size, scheme=scheme, use_autoround=True
) ##only supports 4 bits currently
prompt = "There is a girl who likes adventure,"
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
model = AutoModelForCausalLM.from_pretrained(
quantized_model_path, quantization_config=woq_config, trust_remote_code=True, device="cpu"
)
outputs = model.generate(inputs, max_new_tokens=50)
```
### GPU
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```



17 changes: 10 additions & 7 deletions auto_round/autoround.py
@@ -481,7 +481,7 @@ def __init__(
self.set_layerwise_config(self.weight_config)
self.optimizer = self.get_optimizer(None)
self.check_configs()
torch.set_printoptions(precision=5)
torch.set_printoptions(precision=3, sci_mode=True)

def get_optimizer(self, optimizer):
"""Returns the specified optimizer. In SignRound, we fix the optimizer.
@@ -644,6 +644,8 @@ def calib(self, n_samples):
split=self.dataset_split,
dataset_name=self.dataset_name,
)

self.start_time = time.time()
total_cnt = 0
for data in self.dataloader:
if data is None:
@@ -902,7 +904,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.
last_loss = best_loss
best_iter = last_best_iter
dump_info = (
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))}"
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
)
logger.info(dump_info)
@@ -1014,20 +1016,17 @@ def quantize(self):
Returns:
The quantized model and weight configurations.
"""
start_time = time.time()
# logger.info("cache block input")
block_names = get_block_names(self.model)
if len(block_names) == 0:
logger.warning("could not find blocks, exit with original model")
return

if self.amp:
self.model = self.model.to(self.amp_dtype)
if not self.low_gpu_mem_usage:
self.model = self.model.to(self.device)
inputs = self.cache_block_input(block_names[0], self.n_samples)
del self.inputs

if "input_ids" in inputs.keys():
dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
total_samples = inputs["input_ids"].shape[dim]
@@ -1068,7 +1067,7 @@ def quantize(self):
self.weight_config[n]["sym"] = None

end_time = time.time()
cost_time = end_time - start_time
cost_time = end_time - self.start_time
logger.info(f"quantization tuning time {cost_time}")
## dump a summary
quantized_layers = []
@@ -1079,9 +1078,13 @@ def quantize(self):
unquantized_layers.append(n)
else:
quantized_layers.append(n)
logger.info(
summary_info = (
f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
)
if len(unquantized_layers) > 0:
summary_info += f", {unquantized_layers} have not been quantized"

logger.info(summary_info)
if len(unquantized_layers) > 0:
logger.info(f"Summary: {unquantized_layers} have not been quantized")
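
The timing edits in this file move the start timestamp from `quantize()` into `calib()`, so the reported tuning time also covers calibration. A minimal sketch of the resulting pattern (names simplified, not the actual class):

```python
import time

class TunerSketch:
    """Illustrative only; mirrors the timing pattern, not the real implementation."""

    def calib(self):
        self.start_time = time.time()  # timing now starts when calibration begins
        # ... run calibration over the dataloader and cache block inputs ...

    def quantize(self):
        self.calib()
        # ... block-wise tuning ...
        cost_time = time.time() - self.start_time
        print(f"quantization tuning time {cost_time}")
```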

77 changes: 31 additions & 46 deletions examples/code-generation/main.py
@@ -1,11 +1,7 @@
import argparse
import random
import sys

sys.path.insert(0, '../..')
from auto_round import (AutoRound,
AutoAdamRound)

parser = argparse.ArgumentParser()
import torch
import os
@@ -40,8 +36,9 @@
parser.add_argument("--eval_bs", default=4, type=int,
help="eval batch size")

parser.add_argument("--device", default=0, type=str,
help="device gpu int number, or 'cpu' ")
parser.add_argument("--device", default="auto", type=str,
help="The device to be used for tuning. The default is set to auto/None,"
"allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.")

parser.add_argument("--sym", action='store_true',
help=" sym quantization")
@@ -84,7 +81,7 @@
help="whether enable weight minmax tuning")

parser.add_argument("--deployment_device", default='fake', type=str,
help="targeted inference acceleration platform,The options are 'fake', 'cpu' and 'gpu',"
help="targeted inference acceleration platform,The options are 'fake', 'cpu' and 'gpu'."
"default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.")

parser.add_argument("--scale_dtype", default='fp32',
@@ -101,31 +98,30 @@

args = parser.parse_args()
set_seed(args.seed)

tasks = args.tasks

model_name = args.model_name
if model_name[-1] == "/":
model_name = model_name[:-1]
print(model_name, flush=True)

tasks = args.tasks

if args.device == "cpu":
device_str = "cpu"
else:
device_str = f"cuda:{int(args.device)}"
from auto_round.utils import detect_device
device_str = detect_device(args.device)
torch_dtype = "auto"
if device_str == "hpu":
torch_dtype = torch.bfloat16
torch_device = torch.device(device_str)
is_glm = bool(re.search("chatglm", model_name.lower()))
is_llava = bool(re.search("llava", model_name.lower()))
if is_llava:
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype="auto")
elif is_glm:

if is_glm:
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, torch_dtype="auto", trust_remote_code=True
model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype, trust_remote_code=True
)

from auto_round import (AutoRound,
AutoAdamRound)
model = model.eval()
# align with GPTQ to eval ppl
if "opt" in model_name:
@@ -159,38 +155,27 @@
round = AutoRound
if args.adam:
round = AutoAdamRound
autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, bs=args.train_bs,
autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs,
seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input, device=device_str,
amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage,
seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
scale_dtype=args.scale_dtype, dataset_name="mbpp", dataset_split=['train', 'validation', 'test']) ##TODO args pass
scale_dtype=args.scale_dtype, dataset="mbpp", dataset_split=['train', 'validation', 'test']) ##TODO args pass
model, q_config = autoround.quantize()
model_name = args.model_name.rstrip("/")
export_dir = args.output_dir + "/compressed_" + model_name.split('/')[-1] + "/"
if args.deployment_device == 'cpu':
autoround.export(output_dir=export_dir)
del q_config
elif args.deployment_device == 'gpu':
autoround.export(export_dir, target="auto_gptq", use_triton=True)
model.eval()
if args.device != "cpu":
torch.cuda.empty_cache()
model.eval()
output_dir = args.output_dir + "_" + model_name.split('/')[-1] + f"_w{args.bits}_g{args.group_size}"

import shutil

if os.path.exists(output_dir):
shutil.rmtree(output_dir)

if (hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16):
dtype = 'bfloat16'
pt_dtype = torch.bfloat16
else:
pt_dtype = torch.float16

model = model.to(pt_dtype)
model = model.to("cpu")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}"
output_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}-qdq"
deployment_device = args.deployment_device.split(',')
if 'gpu' in deployment_device:
autoround.save_quantized(f'{export_dir}-gpu', format="auto_gptq", use_triton=True, inplace=False)
if "cpu" in deployment_device:
autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=False)
if "fake" in deployment_device:
model = model.to("cpu")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

7 changes: 5 additions & 2 deletions examples/language-modeling/README.md
@@ -19,7 +19,7 @@ The transformers version required varies across different types of models. Here,
| tiiuae/falcon-7b | 4.28/4.30/4.34/4.36 |
| mosaicml/mpt-7b | 4.28/4.30/4.34/4.36 |
| bigscience/bloom-7b1 | 4.28/4.30/4.34/4.36 |
| baichuan-inc/Baichuan-7B | 4.28/4.30 |
| baichuan-inc/Baichuan2-7B-Chat | 4.36 |
| Qwen/Qwen-7B | 4.28/4.30/4.34/4.36 |
| THUDM/chatglm3-6b | 4.34/4.36 |
| mistralai/Mistral-7B-v0.1 | 4.34/4.36 |
@@ -28,12 +28,14 @@ The transformers version required varies across different types of models. Here,
| databricks/dolly-v2-3b | 4.34 |
| stabilityai/stablelm-base-alpha-3b | 4.34 |
| Intel/neural-chat-7b-v3 | 4.34/4.36 |
| rinna/bilingual-gpt-neox-4b | 4.36 |
| microsoft/phi-2 | 4.36 |


## 2. Prepare Dataset

NeelNanda/pile-10k on Hugging Face is used as the default calibration dataset and will be downloaded automatically from the datasets Hub. To customize the dataset, please follow our dataset code.
See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/loading_datasets.html)
See more about loading [huggingface dataset](https://huggingface.co/docs/datasets/main/en/quickstart)
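
For reference, the default calibration set can be pulled directly with the `datasets` library (a minimal sketch; the example scripts here handle sampling and tokenization themselves):

```python
from datasets import load_dataset

# Download the default calibration set (~10k text samples from the Pile).
calib_data = load_dataset("NeelNanda/pile-10k", split="train")
print(calib_data[0]["text"][:200])  # peek at the first calibration sample
```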

<br />

@@ -106,3 +108,4 @@ If you find SignRound useful for your research, please cite our paper:




13 changes: 6 additions & 7 deletions examples/language-modeling/main.py
@@ -210,13 +210,12 @@ def get_library_version(library_name):
round = AutoAdamRound

weight_config = {}
if 'gpu' in args.deployment_device:
for n, m in model.named_modules():
if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
weight_config[n] = {"data_type": "fp"}
print(
f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to autogptq")
for n, m in model.named_modules():
if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
weight_config[n] = {"data_type": "fp"}
print(
f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to autogptq")

autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs,
seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,