Skip to content

Commit

Permalink
Use kwargs_call
Browse files Browse the repository at this point in the history
Signed-off-by: yuanwu <[email protected]>
  • Loading branch information
yuanwu2017 committed Jun 25, 2024
1 parent 861bec8 commit 8993c64
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions examples/stable-diffusion/text_to_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ def main():
args = parser.parse_args()

# Set image resolution
res = {}
kwargs_call = {}
if args.width > 0 and args.height > 0:
res["width"] = args.width
res["height"] = args.height
kwargs_call["width"] = args.width
kwargs_call["height"] = args.height

# ControlNet
if args.control_image is not None:
Expand Down Expand Up @@ -326,7 +326,7 @@ def main():
with distributed_state.split_between_processes(args.negative_prompts) as negative_prompt:
negative_prompts = negative_prompt

infer_kwargs = {
kwargs_common = {
"num_images_per_prompt": args.num_images_per_prompt,
"batch_size": args.batch_size,
"num_inference_steps": args.num_inference_steps,
Expand All @@ -338,8 +338,9 @@ def main():
"profiling_steps": args.profiling_steps,
}

kwargs_call.update(kwargs_common)
if args.throughput_warmup_steps is not None:
infer_kwargs["throughput_warmup_steps"] = args.throughput_warmup_steps
kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps

# Generate images
if args.control_image is not None:
Expand All @@ -355,7 +356,7 @@ def main():

# Set seed before running the model
set_seed(args.seed)
infer_kwargs["image"] = control_image
kwargs_call["image"] = control_image

elif sdxl:
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
Expand All @@ -377,8 +378,8 @@ def main():
with distributed_state.split_between_processes(args.negative_prompts_2) as negative_prompt_2:
negative_prompts_2 = negative_prompt_2

infer_kwargs["prompt_2"] = prompts_2
infer_kwargs["negative_prompt_2"] = negative_prompts_2
kwargs_call["prompt_2"] = prompts_2
kwargs_call["negative_prompt_2"] = negative_prompts_2

else:
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
Expand All @@ -401,9 +402,9 @@ def main():

if args.distributed:
with distributed_state.split_between_processes(args.prompts) as prompt:
outputs = pipeline(prompt=prompt, **infer_kwargs, **res)
outputs = pipeline(prompt=prompt, **kwargs_call)
else:
outputs = pipeline(prompt=args.prompts, **infer_kwargs, **res)
outputs = pipeline(prompt=args.prompts, **kwargs_call)

# Save the pipeline in the specified directory if not None
if args.pipeline_save_dir is not None:
Expand Down

0 comments on commit 8993c64

Please sign in to comment.