diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 5f33c6fb7..1ac476120 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -60,11 +60,27 @@ python text_to_image_generation.py \ --bf16 ``` +### Distributed inference with multiple HPUs +Here is how to generate images with two prompts on two HPUs: +```bash +python ../gaudi_spawn.py \ + --world_size 2 text_to_image_generation.py \ + --model_name_or_path runwayml/stable-diffusion-v1-5 \ + --prompts "An image of a squirrel in Picasso style" "A shiny flying horse taking off" \ + --num_images_per_prompt 20 \ + --batch_size 4 \ + --image_save_dir /tmp/stable_diffusion_images \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 \ + --distributed +``` + > HPU graphs are recommended when generating images by batches to get the fastest possible generations. > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. - ### Stable Diffusion 2 [Stable Diffusion 2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion_2) can also be used to generate images with this script. Here is an example for a single prompt: @@ -109,6 +125,23 @@ python text_to_image_generation.py \ --gaudi_config Habana/stable-diffusion-2 \ --ldm3d ``` +Here is how to generate images and depth maps with two prompts on two HPUs: +```bash +python ../gaudi_spawn.py \ + --world_size 2 text_to_image_generation.py \ + --model_name_or_path "Intel/ldm3d-4c" \ + --prompts "An image of a squirrel in Picasso style" "A shiny flying horse taking off" \ + --num_images_per_prompt 10 \ + --batch_size 2 \ + --height 768 \ + --width 768 \ + --image_save_dir /tmp/stable_diffusion_images \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion-2 \ + --ldm3d \ + --distributed +``` > There are three different checkpoints for LDM3D: > - use [original checkpoint](https://huggingface.co/Intel/ldm3d) to generate outputs from the paper @@ -173,6 +206,25 @@ python text_to_image_generation.py \ --bf16 ``` +Here is how to generate SDXL images with two prompts on two HPUs: +```bash +python ../gaudi_spawn.py \ + --world_size 2 text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "Sailing ship painting by Van Gogh" "A shiny flying horse taking off" \ + --prompts_2 "Red tone" "Blue tone" \ + --negative_prompts "Low quality" "Sketch" \ + --negative_prompts_2 "Clouds" "Clouds" \ + --num_images_per_prompt 20 \ + --batch_size 8 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --scheduler euler_discrete \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 \ + --distributed +``` > HPU graphs are recommended when generating images by batches to get the fastest possible generations. > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. @@ -244,6 +296,25 @@ python text_to_image_generation.py \ --bf16 ``` +Here is how to generate images conditioned by canny edge model and with two prompts on two HPUs: +```bash +pip install -r requirements.txt +python ../gaudi_spawn.py \ + --world_size 2 text_to_image_generation.py \ + --model_name_or_path runwayml/stable-diffusion-v1-5 \ + --controlnet_model_name_or_path lllyasviel/sd-controlnet-canny \ + --prompts "futuristic-looking woman" "a rusty robot" \ + --control_image https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png \ + --num_images_per_prompt 10 \ + --batch_size 4 \ + --image_save_dir /tmp/controlnet_images \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 \ + --distributed +``` + Here is how to generate images conditioned by open pose model: ```bash pip install -r requirements.txt diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index e52fc5fcb..9dd8fe51a 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -20,6 +20,7 @@ import numpy as np import torch +from accelerate import PartialState from optimum.habana.diffusers import ( GaudiDDIMScheduler, @@ -220,6 +221,7 @@ def main(): default=0, help="Number of steps to capture for profiling.", ) + parser.add_argument("--distributed", action="store_true", help="Use distributed inference on multi-cards") parser.add_argument( "--unet_adapter_name_or_path", default=None, @@ -317,6 +319,26 @@ def main(): if args.bf16: kwargs["torch_dtype"] = torch.bfloat16 + negative_prompts = args.negative_prompts + if args.distributed: + distributed_state = PartialState() + if args.negative_prompts is not None: + with distributed_state.split_between_processes(args.negative_prompts) as negative_prompt: + negative_prompts = negative_prompt + + kwargs_common = { + "num_images_per_prompt": args.num_images_per_prompt, + "batch_size": args.batch_size, + "num_inference_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "negative_prompt": negative_prompts, + "eta": args.eta, + "output_type": args.output_type, + "profiling_warmup_steps": args.profiling_warmup_steps, + "profiling_steps": args.profiling_steps, + } + + kwargs_call.update(kwargs_common) if args.throughput_warmup_steps is not None: kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps @@ -334,21 +356,8 @@ def main(): # Set seed before running the model set_seed(args.seed) + kwargs_call["image"] = control_image - outputs = pipeline( - prompt=args.prompts, - image=control_image, - num_images_per_prompt=args.num_images_per_prompt, - batch_size=args.batch_size, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - negative_prompt=args.negative_prompts, - eta=args.eta, - output_type=args.output_type, - profiling_warmup_steps=args.profiling_warmup_steps, - profiling_steps=args.profiling_steps, - **kwargs_call, - ) elif sdxl: pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( args.model_name_or_path, @@ -360,21 +369,18 @@ def main(): # Set seed before running the model set_seed(args.seed) - outputs = pipeline( - prompt=args.prompts, - prompt_2=args.prompts_2, - num_images_per_prompt=args.num_images_per_prompt, - batch_size=args.batch_size, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - negative_prompt=args.negative_prompts, - negative_prompt_2=args.negative_prompts_2, - eta=args.eta, - output_type=args.output_type, - profiling_warmup_steps=args.profiling_warmup_steps, - profiling_steps=args.profiling_steps, - **kwargs_call, - ) + prompts_2 = args.prompts_2 + negative_prompts_2 = args.negative_prompts_2 + if args.distributed and args.prompts_2 is not None: + with distributed_state.split_between_processes(args.prompts_2) as prompt_2: + prompts_2 = prompt_2 + if args.distributed and args.negative_prompts_2 is not None: + with distributed_state.split_between_processes(args.negative_prompts_2) as negative_prompt_2: + negative_prompts_2 = negative_prompt_2 + + kwargs_call["prompt_2"] = prompts_2 + kwargs_call["negative_prompt_2"] = negative_prompts_2 + else: pipeline = GaudiStableDiffusionPipeline.from_pretrained( args.model_name_or_path, @@ -394,28 +400,26 @@ def main(): pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() set_seed(args.seed) - outputs = pipeline( - prompt=args.prompts, - num_images_per_prompt=args.num_images_per_prompt, - batch_size=args.batch_size, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - negative_prompt=args.negative_prompts, - eta=args.eta, - output_type=args.output_type, - profiling_warmup_steps=args.profiling_warmup_steps, - profiling_steps=args.profiling_steps, - **kwargs_call, - ) + if args.distributed: + with distributed_state.split_between_processes(args.prompts) as prompt: + outputs = pipeline(prompt=prompt, **kwargs_call) + else: + outputs = pipeline(prompt=args.prompts, **kwargs_call) # Save the pipeline in the specified directory if not None if args.pipeline_save_dir is not None: - pipeline.save_pretrained(args.pipeline_save_dir) + save_dir = args.pipeline_save_dir + if args.distributed: + save_dir = f"{args.pipeline_save_dir}_{distributed_state.process_index}" + pipeline.save_pretrained(save_dir) # Save images in the specified directory if not None and if they are in PIL format if args.image_save_dir is not None: if args.output_type == "pil": image_save_dir = Path(args.image_save_dir) + if args.distributed: + image_save_dir = Path(f"{image_save_dir}_{distributed_state.process_index}") + image_save_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving images in {image_save_dir.resolve()}...") if args.ldm3d: