CPU and M1/M2 GPU platform support #80

Open · wants to merge 3 commits into base: main
9 changes: 8 additions & 1 deletion README.md
@@ -131,6 +131,13 @@ You can then try offloading all weights to disk by
python3 -m flexgen.flex_opt --model facebook/opt-175b --percent 0 0 100 0 100 0 --offload-dir YOUR_SSD_FOLDER
```

### CPU and M1/M2 GPU platform
To run models on a CPU platform, all you need to do is add the `--platform` flag:
```
python3 -m flexgen.flex_opt --model facebook/opt-1.3b --platform cpu
```
To run on M1/M2 platforms, [PyTorch nightly](https://pytorch.org/) is required for better kernel coverage and performance. Once you have PyTorch nightly installed, simply replace `cpu` with `mps:0`.
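For example, the command above on an M1/M2 GPU would become:
```
python3 -m flexgen.flex_opt --model facebook/opt-1.3b --platform mps:0
```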

### How to set the offloading strategy and `--percent`?
We will release an automatic policy optimizer later, but for now you have to manually try a few strategies.
The idea of high-throughput generation is to offload parameters and the attention cache as much as possible to the CPU, and to disk if necessary.
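As a rough illustration (a starting point rather than a tuned policy), the six `--percent` values are the GPU and CPU percentages for the weights, the attention cache, and the activations, in that order, with the remainder going to disk. Keeping everything on the CPU, which is what this PR defaults to on the CPU platform, looks like:
```
python3 -m flexgen.flex_opt --model facebook/opt-1.3b --percent 0 100 0 100 0 100
```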
@@ -191,7 +198,7 @@ See [flexgen/apps](flexgen/apps) for more example applications.
## Roadmap
We plan to work on the following features. Community contributions are welcome.

- [ ] Support Apple silicon M1/M2 deployment
- [x] Support Apple silicon M1/M2 deployment
- [ ] Support Colab deployment
- [ ] Support more models (BLOOM, CodeGen, GLM)
- [ ] Release the cost model and policy optimizer
59 changes: 42 additions & 17 deletions flexgen/flex_opt.py
Expand Up @@ -43,6 +43,8 @@ class Policy:
act_gpu_percent: float
act_cpu_percent: float

only_cpu: bool

# Whether to overlap the I/O and compute
overlap: bool

@@ -101,12 +103,8 @@ def init_weight_list(weight_specs, policy, env):
        home = get_choice(mid_percent * 100, dev_percents, dev_choices)
        shape, dtype, filename = weight_specs[i]

        if len(shape) < 2:
            pin_memory = True
            compress = False
        else:
            pin_memory = policy.pin_weight
            compress = policy.compress_weight
        pin_memory = policy.pin_weight
        compress = policy.compress_weight

        if not compress:
            weight = home.allocate(shape, dtype, pin_memory=pin_memory)
@@ -614,10 +612,11 @@ def __init__(self,
        else:
            raise NotImplementedError()

        # CUDA streams
        self.load_weight_stream = torch.cuda.Stream()
        self.load_cache_stream = torch.cuda.Stream()
        self.store_cache_stream = torch.cuda.Stream()
        if self.env.gpu.device_type == DeviceType.CUDA:
            # CUDA streams
            self.load_weight_stream = torch.cuda.Stream()
            self.load_cache_stream = torch.cuda.Stream()
            self.store_cache_stream = torch.cuda.Stream()

        # Intermediate tensors
        # The following buffers store values used
@@ -791,7 +790,8 @@ def compute_layer(self, i, j, k):

    def sync(self):
        self.env.disk.synchronize()
        torch.cuda.synchronize()
        if self.env.gpu.device_type == DeviceType.CUDA:
            torch.cuda.synchronize()

    def init_all_weights(self):
        self.weight_home = array_1d(self.num_layers, ValueHolder)
@@ -1184,15 +1184,17 @@ def run_flexgen(args):
    warmup_inputs = get_test_inputs(32, num_prompts, tokenizer)
    inputs = get_test_inputs(prompt_len, num_prompts, tokenizer)

    gpu = TorchDevice("cuda:0")
    if args.platform == "cpu":
        gpu = TorchDevice("cpu")
    else:
        gpu = TorchDevice(args.platform)
    cpu = TorchDevice("cpu")
    disk = TorchDisk(args.offload_dir)
    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))

    disk = TorchDisk(args.offload_dir, platform=args.platform)
    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]), platform=args.platform)
    policy = Policy(args.gpu_batch_size, args.num_gpu_batches,
                    args.percent[0], args.percent[1],
                    args.percent[2], args.percent[3],
                    args.percent[4], args.percent[5], args.platform == "cpu",
                    args.overlap, args.sep_layer, args.pin_weight,
                    args.cpu_cache_compute, args.attn_sparsity,
                    args.compress_weight,
@@ -1203,7 +1205,8 @@ group_dim=2, symmetric=False))
                                      group_dim=2, symmetric=False))
    assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented"

    opt_config = get_opt_config(args.model)
    # use float32 for CPU platform
    opt_config = get_opt_config(args.model, dtype=np.float32 if args.platform == "cpu" else np.float16)
    cache_size = opt_config.cache_bytes(num_prompts, prompt_len + gen_len)
    hidden_size = opt_config.hidden_bytes(num_prompts, prompt_len + gen_len)
    print(f"model size: {opt_config.model_bytes()/GB:.3f} GB, "
@@ -1311,6 +1314,7 @@ def add_parser_arguments(parser):
    parser.add_argument("--overlap", type=str2bool, nargs='?',
        const=True, default=True)

parser.add_argument("--platform", type=str, default="cuda:0", help="use the number to specify device, the platform can also be cpu or mps")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
@@ -1319,4 +1323,25 @@ def add_parser_arguments(parser):

    assert len(args.percent) == 6

if "cuda" in args.platform:
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
args.platform = "mps:0"
else:
args.platform = "cpu"
print("CUDA devices not available, {} is used instead".format(args.platform))

if "mps" in args.platform:
if not torch.backends.mps.is_available():
args.platform = "cpu"
print("MPS devices not available, CPU is used instead")

if "cuda" not in args.platform:
# not clear how to enable overlap on MPS platform yet
args.overlap = False
args.pin_weight = False

if args.platform == "cpu":
args.percent = [0, 100, 0, 100, 0, 100]

run_flexgen(args)
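For readers skimming the diff, here is a self-contained sketch of the fallback behavior implemented above; the helper name is hypothetical and only public `torch` APIs are used:
```
import torch

def resolve_platform(requested="cuda:0"):
    # Mirror the PR's fallback: requested CUDA device -> MPS -> CPU.
    if "cuda" in requested and not torch.cuda.is_available():
        requested = "mps:0" if torch.backends.mps.is_available() else "cpu"
    if "mps" in requested and not torch.backends.mps.is_available():
        requested = "cpu"
    return requested

print(resolve_platform())  # e.g. "mps:0" on an M1/M2 Mac without CUDA
```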
22 changes: 12 additions & 10 deletions flexgen/opt_config.py
@@ -54,6 +54,8 @@ def get_opt_config(name, **kwargs):
        name = name.split("/")[1]
    name = name.lower()

    dtype = kwargs["dtype"] if "dtype" in kwargs else OptConfig.dtype

    # Handle opt-iml-30b and opt-iml-max-30b
    if "-iml-max" in name:
        arch_name = name.replace("-iml-max", "")
@@ -65,54 +67,54 @@
if arch_name == "opt-125m":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=12, n_head=12,
hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4,
hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, dtype=dtype
)
elif arch_name == "opt-350m":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=24, n_head=16,
hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4,
hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, dtype=dtype
)
raise NotImplementedError("Not implemented because this model "
"has a different architecture")
elif arch_name == "opt-1.3b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=24, n_head=32,
hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4,
hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, dtype=dtype
)
elif arch_name == "opt-2.7b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=32, n_head=32,
hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4,
hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, dtype=dtype
)
elif arch_name == "opt-6.7b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=32, n_head=32,
hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4,
hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, dtype=dtype
)
elif arch_name == "opt-13b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=40, n_head=40,
hidden_size=5120, input_dim=5120, ffn_embed_dim=5120 * 4,
hidden_size=5120, input_dim=5120, ffn_embed_dim=5120 * 4, dtype=dtype
)
elif arch_name == "opt-30b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=48, n_head=56,
hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4,
hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, dtype=dtype
)
elif arch_name == "opt-66b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=64, n_head=72,
hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4,
hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, dtype=dtype
)
elif arch_name == "opt-175b":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=96, n_head=96,
hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,
hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, dtype=dtype
)
elif arch_name == "opt-175b-stage":
config = OptConfig(name=name,
max_seq_len=2048, num_hidden_layers=24, n_head=96,
hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,
hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, dtype=dtype
)
else:
raise ValueError(f"Invalid model name: {name}")
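Finally, a minimal usage sketch (not part of the diff) of the new `dtype` keyword, assuming the package layout shown above:
```
import numpy as np
from flexgen.opt_config import get_opt_config

# On the CPU platform the PR requests float32, since float16 compute is
# poorly supported on CPU; CUDA and MPS keep the float16 default.
config = get_opt_config("facebook/opt-1.3b", dtype=np.float32)
print(config.name, config.dtype)
```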