Skip to content

Commit

Permalink
feat(nodes): raise on NSFW
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed May 13, 2024
1 parent 8bddcce commit 60d16fd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 24 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/services/config/config_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")

# NODES
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def save(
board_id_ = self._data.invocation.board.board_id

if self._services.configuration.nsfw_check:
image = SafetyChecker.blur_if_nsfw(image)
SafetyChecker.raise_if_nsfw(image)

if self._services.configuration.watermark:
image = InvisibleWatermark.add_watermark(image, "InvokeAI")
Expand Down
34 changes: 12 additions & 22 deletions invokeai/backend/image_util/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
configuration variable, that allows the checker to be supressed.
"""

from pathlib import Path

import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image, ImageFilter
from PIL import Image
from transformers import AutoFeatureExtractor

import invokeai.backend.util.logging as logger
Expand All @@ -20,10 +18,15 @@
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


class NSFWImageException(Exception):
"""Raised when a NSFW image is detected."""

def __init__(self):
super().__init__("A potentially NSFW image has been detected.")


class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
"""Wrapper around SafetyChecker model."""

feature_extractor = None
safety_checker = None
Expand Down Expand Up @@ -72,22 +75,9 @@ def has_nsfw_concept(cls, image: Image.Image) -> bool:
return has_nsfw_concept[0]

@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
"""Raises an exception if the image contains NSFW content."""
if cls.has_nsfw_concept(image):
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
x = (blurry_image.width - caution.width) // 2
y = (blurry_image.height - caution.height) // 2
blurry_image.paste(caution, (x, y), caution)
image = blurry_image
raise NSFWImageException()

return image

@classmethod
def _get_caution_img(cls) -> Image.Image:
import invokeai.app.assets.images as image_assets

caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))

0 comments on commit 60d16fd

Please sign in to comment.