Commit a66ae10: warmup fix

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Jun 26, 2024
1 parent 872b167 commit a66ae10
Showing 4 changed files with 60 additions and 21 deletions.
9 changes: 9 additions & 0 deletions examples/stable-diffusion/image_to_image_generation.py
@@ -195,6 +195,12 @@ def main():
type=int,
help="Number of steps to capture for profiling.",
)
parser.add_argument(
"--throughput_warmup_steps",
type=int,
default=None,
help="Number of steps to ignore for throughput calculation.",
)
args = parser.parse_args()

# Set image resolution
@@ -255,6 +261,9 @@ def main():
if args.bf16:
kwargs["torch_dtype"] = torch.bfloat16

if args.throughput_warmup_steps is not None:
kwargs["throughput_warmup_steps"] = args.throughput_warmup_steps

pipeline = Img2ImgPipeline.from_pretrained(
args.model_name_or_path,
**kwargs,
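The example script only forwards the new flag when the user actually sets it, so the pipelines keep their built-in default of 3 warm-up steps (visible in the kwargs.get("throughput_warmup_steps", 3) calls further down). A minimal sketch of that pass-through pattern; run_pipeline here is a hypothetical stand-in for the Gaudi pipeline call and is not part of the commit:

import argparse


def run_pipeline(**kwargs):
    # Hypothetical stand-in for the pipeline: fall back to 3 warm-up steps
    # when the keyword is not passed, mirroring kwargs.get(..., 3) below.
    warmup = kwargs.get("throughput_warmup_steps", 3)
    print(f"ignoring the first {warmup} iterations for throughput")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--throughput_warmup_steps",
    type=int,
    default=None,
    help="Number of steps to ignore for throughput calculation.",
)
args = parser.parse_args()

kwargs = {}
if args.throughput_warmup_steps is not None:
    # Forward the flag only when explicitly set, preserving the pipeline default.
    kwargs["throughput_warmup_steps"] = args.throughput_warmup_steps

run_pipeline(**kwargs)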
Second changed file (path not shown in this view):
@@ -30,7 +30,7 @@
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
-from ....utils import HabanaProfile, speed_metrics
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline


@@ -320,18 +320,24 @@ def __call__(
t0 = time.time()
t1 = t0
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
use_warmup_inference_steps = (
num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
)
for j in self.progress_bar(range(num_batches)):
# The throughput is calculated from the 3rd iteration
# because compilation occurs in the first two iterations
if j == throughput_warmup_steps:
t1 = time.time()

if use_warmup_inference_steps:
t0_inf = time.time()
latents_batch = latents_batches[0]
latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
image_embeddings_batch = image_embeddings_batches[0]
image_embeddings_batches = torch.roll(image_embeddings_batches, shifts=-1, dims=0)
for i in range(len(timesteps)):
if use_warmup_inference_steps and i == throughput_warmup_steps:
t1_inf = time.time()
t1 += t1_inf - t0_inf
t = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)
# expand the latents if we are doing classifier free guidance
@@ -354,11 +360,14 @@ def __call__(
self.htcore.mark_step()

# call the callback, if provided
-if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-    if callback is not None and i % callback_steps == 0:
-        step_idx = i // getattr(self.scheduler, "order", 1)
-        callback(step_idx, t, latents_batch)
+if callback is not None and i % callback_steps == 0:
+    step_idx = i // getattr(self.scheduler, "order", 1)
+    callback(step_idx, t, latents_batch)
hb_profiler.step()
if use_warmup_inference_steps:
t1 = warmup_inference_steps_time_adjustment(
t1, t1_inf, num_inference_steps, throughput_warmup_steps
)
if not output_type == "latent":
image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
else:
@@ -373,7 +382,7 @@ def __call__(
split=speed_metrics_prefix,
start_time=t0,
num_samples=num_batches * batch_size
-if t1 == t0
+if t1 == t0 or use_warmup_inference_steps
else (num_batches - throughput_warmup_steps) * batch_size,
num_steps=num_batches,
start_time_after_warmup=t1,
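The three pipeline changes in this commit all follow the pattern above. Previously warm-up was only handled at batch granularity: t1 was reset once batch index j reached throughput_warmup_steps, so a run with fewer batches than warm-up steps (for example a single batch) never excluded graph-compilation time at all. The new use_warmup_inference_steps path detects that case and instead times the first throughput_warmup_steps denoising steps inside the batch (t0_inf to t1_inf), shifts t1 forward by that span, and then calls warmup_inference_steps_time_adjustment so that those steps are still credited at the steady-state rate and only the compilation overhead stays excluded. The helper itself is imported from ....utils and its body is not part of this diff; the following is an illustrative sketch of that adjustment, assuming only the four arguments seen in the call sites above, with the names and exact formula being mine rather than the library's:

import time


def warmup_inference_steps_time_adjustment(
    start_time_after_warmup: float,
    start_time_after_inference_steps_warmup: float,
    num_inference_steps: int,
    warmup_steps: int,
) -> float:
    """Illustrative sketch only: re-credit the excluded warm-up steps at the
    steady-state per-step rate so throughput is not over-estimated.

    start_time_after_warmup: t1, i.e. t0 shifted forward by the measured
        warm-up span(s).
    start_time_after_inference_steps_warmup: t1_inf, wall-clock time at which
        the warm-up steps of the current batch finished.
    """
    if num_inference_steps > warmup_steps:
        # Average cost of one denoising step after warm-up.
        steady_step_time = (time.time() - start_time_after_inference_steps_warmup) / (
            num_inference_steps - warmup_steps
        )
        # Move t1 back by what the warm-up steps would have cost at steady
        # state, so only the compilation overhead remains excluded.
        start_time_after_warmup -= steady_step_time * warmup_steps
    return start_time_after_warmup

With an adjustment of this kind, the start_time_after_warmup handed to speed_metrics reflects steady-state work even when the whole run fits in a single batch.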
Third changed file (path not shown in this view):
@@ -31,7 +31,7 @@
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
-from ....utils import HabanaProfile, speed_metrics
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline


@@ -403,12 +403,16 @@ def __call__(
t0 = time.time()
t1 = t0
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
use_warmup_inference_steps = (
num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
)
for j in self.progress_bar(range(num_batches)):
# The throughput is calculated from the 3rd iteration
# because compilation occurs in the first two iterations
if j == throughput_warmup_steps:
t1 = time.time()
if use_warmup_inference_steps:
t0_inf = time.time()

latents_batch = latents_batches[0]
latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
@@ -418,6 +422,9 @@ def __call__(
prompt_embeds_batches = torch.roll(prompt_embeds_batches, shifts=-1, dims=0)

for i in range(len(timesteps)):
if use_warmup_inference_steps and i == throughput_warmup_steps:
t1_inf = time.time()
t1 += t1_inf - t0_inf
t = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)
# expand the latents if we are doing classifier free guidance
@@ -470,11 +477,14 @@ def __call__(
image_latents_batch = callback_outputs.pop("image_latents", image_latents_batch)

# call the callback, if provided
-if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-    if callback is not None and i % callback_steps == 0:
-        step_idx = i // getattr(self.scheduler, "order", 1)
-        callback(step_idx, t, latents_batch)
+if callback is not None and i % callback_steps == 0:
+    step_idx = i // getattr(self.scheduler, "order", 1)
+    callback(step_idx, t, latents_batch)
hb_profiler.step()
if use_warmup_inference_steps:
t1 = warmup_inference_steps_time_adjustment(
t1, t1_inf, num_inference_steps, throughput_warmup_steps
)

if not output_type == "latent":
image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
@@ -490,7 +500,7 @@ def __call__(
split=speed_metrics_prefix,
start_time=t0,
num_samples=num_batches * batch_size
-if t1 == t0
+if t1 == t0 or use_warmup_inference_steps
else (num_batches - throughput_warmup_steps) * batch_size,
num_steps=num_batches,
start_time_after_warmup=t1,
Fourth changed file (path not shown in this view):
@@ -37,7 +37,7 @@
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
-from ....utils import HabanaProfile, speed_metrics
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline
from ..stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

@@ -546,11 +546,16 @@ def denoising_value_valid(dnv):

# 8.3 Denoising loop
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
use_warmup_inference_steps = (
num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
)
for j in self.progress_bar(range(num_batches)):
# The throughput is calculated from the 3rd iteration
# because compilation occurs in the first two iterations
if j == throughput_warmup_steps:
t1 = time.time()
if use_warmup_inference_steps:
t0_inf = time.time()

latents_batch = latents_batches[0]
latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
@@ -562,6 +567,9 @@ def denoising_value_valid(dnv):
add_time_ids_batches = torch.roll(add_time_ids_batches, shifts=-1, dims=0)

for i in range(len(timesteps)):
if use_warmup_inference_steps and i == throughput_warmup_steps:
t1_inf = time.time()
t1 += t1_inf - t0_inf
if self.interrupt:
continue
timestep = timesteps[0]
@@ -626,12 +634,15 @@ def denoising_value_valid(dnv):
add_time_ids_batch = torch.cat([_add_time_ids, _negative_add_time_ids])

# call the callback, if provided
-if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-    if callback is not None and i % callback_steps == 0:
-        step_idx = i // getattr(self.scheduler, "order", 1)
-        callback(step_idx, timestep, latents)
+if callback is not None and i % callback_steps == 0:
+    step_idx = i // getattr(self.scheduler, "order", 1)
+    callback(step_idx, timestep, latents)

hb_profiler.step()
if use_warmup_inference_steps:
t1 = warmup_inference_steps_time_adjustment(
t1, t1_inf, num_inference_steps, throughput_warmup_steps
)

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
@@ -662,7 +673,7 @@ def denoising_value_valid(dnv):
split=speed_metrics_prefix,
start_time=t0,
num_samples=num_batches * batch_size
-if t1 == t0
+if t1 == t0 or use_warmup_inference_steps
else (num_batches - throughput_warmup_steps) * batch_size,
num_steps=num_batches,
start_time_after_warmup=t1,
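The last change in each pipeline adjusts the sample count passed to speed_metrics. Before, the t1 == t0 branch covered the case where no warm-up batch was ever skipped; now that t1 moves even in the few-batches case, the extra "or use_warmup_inference_steps" keeps all num_batches * batch_size images in the count, since no batches were dropped and the adjusted timing window alone absorbs the warm-up. A small worked illustration with made-up numbers, assuming throughput is reported as num_samples / (end_time - start_time_after_warmup); the actual formula lives in speed_metrics, which this diff does not show:

# Illustrative numbers only, not measurements.
num_batches, batch_size = 1, 4
throughput_warmup_steps = 3

t0 = 0.0
end_time = 13.0            # total wall-clock time of the run
warmup_span = 6.0          # first 3 denoising steps, dominated by compilation
steady_state_credit = 1.5  # what those 3 steps would cost at steady state

# Before this commit (single batch): j never reaches 3, so t1 == t0 and
# compilation time is never excluded from the throughput window.
old_throughput = num_batches * batch_size / (end_time - t0)   # ~0.31 images/s

# After this commit: t1 = t0 + warmup_span, then adjusted back by the
# steady-state credit; all generated images still count as samples.
t1 = t0 + warmup_span - steady_state_credit
new_num_samples = num_batches * batch_size  # "or use_warmup_inference_steps" branch
new_throughput = new_num_samples / (end_time - t1)             # ~0.47 images/s
print(round(old_throughput, 2), round(new_throughput, 2))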
