From 8993c64a7a800c29133029e2f22a22c33b765eaa Mon Sep 17 00:00:00 2001 From: yuanwu Date: Tue, 25 Jun 2024 12:33:26 +0000 Subject: [PATCH] Use kwargs_call Signed-off-by: yuanwu --- .../text_to_image_generation.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index 70280666f..9dd8fe51a 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -243,10 +243,10 @@ def main(): args = parser.parse_args() # Set image resolution - res = {} + kwargs_call = {} if args.width > 0 and args.height > 0: - res["width"] = args.width - res["height"] = args.height + kwargs_call["width"] = args.width + kwargs_call["height"] = args.height # ControlNet if args.control_image is not None: @@ -326,7 +326,7 @@ def main(): with distributed_state.split_between_processes(args.negative_prompts) as negative_prompt: negative_prompts = negative_prompt - infer_kwargs = { + kwargs_common = { "num_images_per_prompt": args.num_images_per_prompt, "batch_size": args.batch_size, "num_inference_steps": args.num_inference_steps, @@ -338,8 +338,9 @@ def main(): "profiling_steps": args.profiling_steps, } + kwargs_call.update(kwargs_common) if args.throughput_warmup_steps is not None: - infer_kwargs["throughput_warmup_steps"] = args.throughput_warmup_steps + kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps # Generate images if args.control_image is not None: @@ -355,7 +356,7 @@ def main(): # Set seed before running the model set_seed(args.seed) - infer_kwargs["image"] = control_image + kwargs_call["image"] = control_image elif sdxl: pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( @@ -377,8 +378,8 @@ def main(): with distributed_state.split_between_processes(args.negative_prompts_2) as negative_prompt_2: negative_prompts_2 = negative_prompt_2 - infer_kwargs["prompt_2"] = prompts_2 - infer_kwargs["negative_prompt_2"] = negative_prompts_2 + kwargs_call["prompt_2"] = prompts_2 + kwargs_call["negative_prompt_2"] = negative_prompts_2 else: pipeline = GaudiStableDiffusionPipeline.from_pretrained( @@ -401,9 +402,9 @@ def main(): if args.distributed: with distributed_state.split_between_processes(args.prompts) as prompt: - outputs = pipeline(prompt=prompt, **infer_kwargs, **res) + outputs = pipeline(prompt=prompt, **kwargs_call) else: - outputs = pipeline(prompt=args.prompts, **infer_kwargs, **res) + outputs = pipeline(prompt=args.prompts, **kwargs_call) # Save the pipeline in the specified directory if not None if args.pipeline_save_dir is not None: