
Add test_stable_diffusion_inpaint.py testcases
Signed-off-by: yuanwu <[email protected]>
yuanwu2017 committed Apr 26, 2024
1 parent 18a096a commit f274cca
Showing 6 changed files with 425 additions and 35 deletions.
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import time
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -485,14 +486,16 @@ def __call__(
         ).to(device=device, dtype=latents.dtype)
 
         # 10. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 0)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, throughput_warmup_steps)
         self._num_timesteps = len(timesteps)
 
         t0 = time.time()
         t1 = t0
 
+        const_timesteps = copy.deepcopy(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i in range(num_inference_steps):
+            for i, _ in enumerate(const_timesteps):
                 if self.interrupt:
                     continue
                 timestep = timesteps[0]
@@ -534,8 +537,8 @@ def __call__(
                     else:
                         init_mask = mask
 
-                    if i < len(timesteps) - 1:
-                        noise_timestep = timesteps[i + 1]
+                    if i < len(const_timesteps) - 1:
+                        noise_timestep = const_timesteps[i + 1]
                         init_latents_proper = self.scheduler.add_noise(
                             init_latents_proper, noise, torch.tensor([noise_timestep])
                         )
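The two hunks above change the denoising loop of the Gaudi inpaint pipeline in two ways. First, the warmup count now honors an optional throughput_warmup_steps kwarg, so the first iterations (where HPU graph compilation dominates) can be excluded from the throughput window measured via t0/t1. Second, the loop iterates over a deep copy of timesteps: the loop consumes the live schedule in place (it always reads timesteps[0]), so indexing it with i + 1 for add_noise would drift. A minimal sketch of the pattern, assuming the schedule tensor is rotated each step as in other Gaudi pipelines:

    import copy

    import torch

    timesteps = torch.tensor([901, 801, 701, 601])
    const_timesteps = copy.deepcopy(timesteps)  # stable copy for lookups

    for i, _ in enumerate(const_timesteps):
        timestep = timesteps[0]  # always read the head of the live schedule...
        timesteps = torch.roll(timesteps, shifts=-1, dims=0)  # ...then rotate it
        if i < len(const_timesteps) - 1:
            # the copy still indexes correctly even though timesteps mutated
            noise_timestep = const_timesteps[i + 1]

The same loop rewrite is applied to the SDXL inpaint pipeline in the next file.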
@@ -683,8 +683,7 @@ def denoising_value_valid(dnv):
         self._num_timesteps = len(timesteps)
         const_timesteps = copy.deepcopy(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            #for i, t in enumerate(timesteps):
-            for i in range(num_inference_steps):
+            for i, _ in enumerate(const_timesteps):
                 if self.interrupt:
                     continue
 
3 changes: 3 additions & 0 deletions setup.py
@@ -38,13 +38,16 @@
 ]
 
 TESTS_REQUIRE = [
+    "diffusers[test] >= 0.26.0, < 0.27.0",
     "psutil",
     "parameterized",
     "GitPython",
    "optuna",
     "sentencepiece",
     "datasets",
     "safetensors",
+    "scipy",
+    "torchsde"
 ]
 
 QUALITY_REQUIRES = [
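The test extra now pins diffusers to the 0.26.x series and adds scipy and torchsde, which the shared scheduler tests import indirectly (torchsde backs DPMSolverSDEScheduler, scipy LMSDiscreteScheduler). A hypothetical runtime guard mirroring the new pin, not part of the commit:

    import diffusers
    from packaging import version

    v = version.parse(diffusers.__version__)
    assert version.parse("0.26.0") <= v < version.parse("0.27.0"), (
        f"tests expect diffusers 0.26.x, found {diffusers.__version__}"
    )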
61 changes: 34 additions & 27 deletions tests/test_pipelines_common.py
@@ -39,12 +39,10 @@
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
-from diffusers.utils.testing_utils import (
-    CaptureLogger,
-    require_torch,
-    torch_device,
-)
+from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
 
 
+#torch_device="hpu"
 
 def to_np(tensor):
     if isinstance(tensor, torch.Tensor):
@@ -122,7 +120,7 @@ def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
     def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         output_pt = pipe(
@@ -149,7 +147,7 @@ def test_pt_np_pil_inputs_equivalent(self):
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
@@ -168,7 +166,7 @@ def test_latents_input(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
@@ -201,7 +199,7 @@ def test_karras_schedulers_shape(self):
             # make sure that PNDM does not need warm-up
             pipe.scheduler.register_to_config(skip_prk_steps=True)
 
-            pipe.to(torch_device)
+            pipe.to("cpu")
             pipe.set_progress_bar_config(disable=None)
             inputs = self.get_dummy_inputs(torch_device)
             inputs["num_inference_steps"] = 2
@@ -254,7 +252,7 @@ class PipelineTesterMixin:
     test_xformers_attention = True
 
     def get_generator(self, seed):
-        device = torch_device if torch_device != "mps" else "cpu"
+        device = "cpu"
         generator = torch.Generator(device).manual_seed(seed)
         return generator
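Pinning get_generator to the CPU sidesteps per-device RNG streams: a CPU torch.Generator produces the same sequence regardless of the accelerator, which keeps dummy inputs reproducible when the pipelines themselves run on HPU. A small illustration:

    import torch

    # identical latents on any host, since seeding happens on CPU
    generator = torch.Generator("cpu").manual_seed(0)
    latents = torch.randn(1, 4, 8, 8, generator=generator)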

@@ -327,19 +325,19 @@ def tearDown(self):
         torch.cuda.empty_cache()
 
     def test_save_load_local(self, expected_max_difference=5e-4):
+        #set_seed(0)
         components = self.get_dummy_components()
+        init_kwargs = {
+            "use_habana": True,
+            "use_hpu_graphs": True,
+            "gaudi_config": "Habana/stable-diffusion",
+            "bf16_full_eval": True
+        }
+        print(f"components.keys={components.keys()}")
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
+        print("torch_device")
 
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
@@ -351,7 +349,6 @@ def test_save_load_local(self, expected_max_difference=5e-4):
 
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir, safe_serialization=False)
-
             with CaptureLogger(logger) as cap_logger:
-                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, **init_kwargs)
@@ -419,7 +416,7 @@ def _test_inference_batch_consistent(
     ):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
@@ -478,7 +475,7 @@ def _test_inference_batch_single_identical(
         if hasattr(components, "set_default_attn_processor"):
             components.set_default_attn_processor()
 
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         inputs = self.get_dummy_inputs(torch_device)
         # Reset generator in case it is has been used in self.get_dummy_inputs
@@ -528,7 +525,7 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
 
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator_device = "cpu"
@@ -540,9 +537,14 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
 
     def test_components_function(self):
         init_components = self.get_dummy_components()
-        init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
+
+        #init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
 
         pipe = self.pipeline_class(**init_components)
+        init_components.pop("use_habana")
+        init_components.pop("use_hpu_graphs")
+        init_components.pop("bf16_full_eval")
+        init_components.pop("gaudi_config")
 
         self.assertTrue(hasattr(pipe, "components"))
         self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
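The pops are needed because the four Gaudi kwargs arrive through the constructor but configure the runtime rather than a model component, so they never show up in pipe.components. An equivalent filter by name, shown only for clarity:

    gaudi_only = {"use_habana", "use_hpu_graphs", "bf16_full_eval", "gaudi_config"}
    expected = {k: v for k, v in init_components.items() if k not in gaudi_only}
    assert set(pipe.components.keys()) == set(expected.keys())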
@@ -627,6 +629,12 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
     def test_save_load_optional_components(self, expected_max_difference=1e-4):
         if not hasattr(self.pipeline_class, "_optional_components"):
             return
+        init_kwargs = {
+            "use_habana": True,
+            "use_hpu_graphs": True,
+            "gaudi_config": "Habana/stable-diffusion",
+            "bf16_full_eval": True
+        }
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -646,7 +654,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
 
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir, safe_serialization=False)
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, **init_kwargs)
             for component in pipe_loaded.components.values():
                 if hasattr(component, "set_default_attn_processor"):
                     component.set_default_attn_processor()
@@ -711,7 +719,7 @@ def _test_attention_slicing_forward_pass(
         for component in pipe.components.values():
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator_device = "cpu"
@@ -739,7 +747,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
         for component in pipe.components.values():
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator_device = "cpu"
@@ -807,7 +815,7 @@ def _test_xformers_attention_forwardGenerator_pass(
         for component in pipe.components.values():
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
@@ -833,7 +841,7 @@ def _test_xformers_attention_forwardGenerator_pass(
     def test_progress_bar(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
         with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
@@ -860,7 +868,7 @@ def test_num_images_per_prompt(self):
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         batch_sizes = [1, 2]
@@ -886,7 +894,7 @@ def test_cfg(self):
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
@@ -909,7 +917,7 @@ def test_callback_inputs(self):
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
+        #pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         self.assertTrue(
             hasattr(pipe, "_callback_tensor_inputs"),
@@ -974,7 +982,7 @@ def test_callback_cfg(self):
 
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
+        #pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         self.assertTrue(
             hasattr(pipe, "_callback_tensor_inputs"),
@@ -983,7 +991,6 @@ def test_callback_cfg(self):
 
         def callback_increase_guidance(pipe, i, t, callback_kwargs):
             pipe._guidance_scale += 1.0
-
             return callback_kwargs
 
         inputs = self.get_dummy_inputs(torch_device)
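For context, callback_increase_guidance is wired in through the callback_on_step_end argument that these pipelines accept; a usage sketch with illustrative prompt and step count:

    # guidance_scale grows by 1.0 after every denoising step
    image = pipe(
        "a photo of an astronaut",
        num_inference_steps=2,
        guidance_scale=5.0,
        callback_on_step_end=callback_increase_guidance,
    ).images[0]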
(Diffs for the two remaining changed files, including the new tests/test_stable_diffusion_inpaint.py, did not load.)