From b8bcde07407d81efc1fc14f1424dee2bb686259c Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Fri, 12 Apr 2024 01:30:22 +0000
Subject: [PATCH] Add the MC example

Signed-off-by: yuanwu
---
 examples/stable-diffusion/README.md          |  14 +++
 examples/stable-diffusion/run_distributed.py | 102 ++++++++++++++++++++++
 2 files changed, 116 insertions(+)
 create mode 100644 examples/stable-diffusion/run_distributed.py

diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index accb8737f..144161ae5 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -276,3 +276,17 @@ python text_to_image_generation.py \
     --use_hpu_graphs \
     --gaudi_config Habana/stable-diffusion-2
 ```
+
+### Distributed inference with multiple HPUs
+
+Here is how to generate two images with two prompts on two HPUs:
+```python
+python ../gaudi_spawn.py \
+    --world_size 2 run_distributed.py \
+    --model_name_or_path runwayml/stable-diffusion-v1-5 \
+    --prompts "a cat" "a dog" \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --bf16
+```
\ No newline at end of file
diff --git a/examples/stable-diffusion/run_distributed.py b/examples/stable-diffusion/run_distributed.py
new file mode 100644
index 000000000..56cb53237
--- /dev/null
+++ b/examples/stable-diffusion/run_distributed.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from: https://huggingface.co/docs/diffusers/en/training/distributed_inference
+ - Use the GaudiStableDiffusionPipeline
+"""
+import argparse
+import logging
+
+import torch
+from accelerate import PartialState
+
+from optimum.habana.diffusers import GaudiStableDiffusionPipeline
+from optimum.habana.utils import set_seed
+
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model_name_or_path",
+        default="runwayml/stable-diffusion-v1-5",
+        type=str,
+        help="Path to the pre-trained model.",
+    )
+    # Pipeline arguments
+    parser.add_argument(
+        "--prompts",
+        type=str,
+        nargs="*",
+        default=["a dog", "a cat"],
+        help="The prompt or prompts to guide the image generation.",
+    )
+    parser.add_argument(
+        "--num_images_per_prompt", type=int, default=1, help="The number of images to generate per prompt."
+    )
+    parser.add_argument("--seed", type=int, default=None, help="Random seed for initialization.")
+    parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")
+    parser.add_argument(
+        "--gaudi_config",
+        type=str,
+        default="Habana/stable-diffusion",
+        help=(
+            "Name or path of the Gaudi configuration. In particular, it enables to specify how to apply Habana Mixed"
+            " Precision."
+        ),
+    )
+    # HPU-specific arguments
+    parser.add_argument("--use_habana", action="store_true", help="Use HPU.")
+    parser.add_argument(
+        "--use_hpu_graphs", action="store_true", help="Use HPU graphs on HPU. This should lead to faster generations."
+    )
+    args = parser.parse_args()
+
+    # Set the seed before running the model
+    if args.seed is not None:
+        logger.info("Setting the random seed to {}".format(args.seed))
+        set_seed(args.seed)
+
+    kwargs = {
+        "use_habana": args.use_habana,
+        "use_hpu_graphs": args.use_hpu_graphs,
+        "gaudi_config": args.gaudi_config,
+        "torch_dtype": torch.bfloat16 if args.bf16 else None,
+    }
+    logger.info(f"Initializing the pipeline with kwargs: {kwargs}")
+    pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+        args.model_name_or_path, use_safetensors=True, **kwargs
+    )
+    # PartialState holds the rank and world size of each process spawned by gaudi_spawn.py
+    distributed_state = PartialState()
+    gen_kwargs = {"num_images_per_prompt": args.num_images_per_prompt}
+    # Each process generates images for its own share of the prompts
+    with distributed_state.split_between_processes(args.prompts) as prompt:
+        outputs = pipeline(prompt, **gen_kwargs)
+        for i, image in enumerate(outputs.images):
+            image.save(f"result_{distributed_state.process_index}_{i}.png")
+
+
+if __name__ == "__main__":
+    main()