Skip to content

Commit

Permalink
Normalize frame processor to clear them all via UI
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Aug 13, 2023
1 parent e4936df commit 270139b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 38 deletions.
24 changes: 12 additions & 12 deletions roop/processors/frame/__modules__/face_enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,31 @@
from roop.typing import Frame, Face
from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video

FACE_ENHANCER = None
FRAME_PROCESSOR = None
THREAD_SEMAPHORE = threading.Semaphore()
THREAD_LOCK = threading.Lock()
NAME = 'ROOP.PROCESSORS.FRAME.FACE_ENHANCER'
NAME = 'ROOP.FRAME_PROCESSOR.FACE_ENHANCER'


def get_face_enhancer() -> Any:
global FACE_ENHANCER
def get_frame_processor() -> Any:
global FRAME_PROCESSOR

with THREAD_LOCK:
if FACE_ENHANCER is None:
if FRAME_PROCESSOR is None:
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
# todo: set models path -> https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = GFPGANer(
FRAME_PROCESSOR = GFPGANer(
model_path=model_path,
upscale=1,
device=frame_processors.get_device()
)
return FACE_ENHANCER
return FRAME_PROCESSOR


def clear_face_enhancer() -> None:
global FACE_ENHANCER
def clear_frame_processor() -> None:
global FRAME_PROCESSOR

FACE_ENHANCER = None
FRAME_PROCESSOR = None


def pre_check() -> bool:
Expand All @@ -51,7 +51,7 @@ def pre_start() -> bool:


def post_process() -> None:
clear_face_enhancer()
clear_frame_processor()


def enhance_face(target_face: Face, temp_frame: Frame) -> Frame:
Expand All @@ -65,7 +65,7 @@ def enhance_face(target_face: Face, temp_frame: Frame) -> Frame:
crop_frame = temp_frame[start_y:end_y, start_x:end_x]
if crop_frame.size:
with THREAD_SEMAPHORE:
_, _, crop_frame = get_face_enhancer().enhance(
_, _, crop_frame = get_frame_processor().enhance(
crop_frame,
paste_back=True
)
Expand Down
30 changes: 17 additions & 13 deletions roop/processors/frame/__modules__/face_swapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
from roop.typing import Face, Frame
from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video

FACE_SWAPPER = None
FRAME_PROCESSOR = None
THREAD_LOCK = threading.Lock()
NAME = 'ROOP.PROCESSORS.FRAME.FACE_SWAPPER'
NAME = 'ROOP.FRAME_PROCESSOR..FACE_SWAPPER'


def get_face_swapper() -> Any:
global FACE_SWAPPER
def get_frame_processor() -> Any:
global FRAME_PROCESSOR

with THREAD_LOCK:
if FACE_SWAPPER is None:
if FRAME_PROCESSOR is None:
model_path = resolve_relative_path('../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER
FRAME_PROCESSOR = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FRAME_PROCESSOR


def clear_face_swapper() -> None:
global FACE_SWAPPER
def clear_frame_processor() -> None:
global FRAME_PROCESSOR

FACE_SWAPPER = None
FRAME_PROCESSOR = None


def pre_check() -> bool:
Expand All @@ -52,12 +52,12 @@ def pre_start() -> bool:


def post_process() -> None:
clear_face_swapper()
clear_frame_processor()
clear_face_reference()


def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
return get_frame_processor().get(temp_frame, target_face, source_face, paste_back=True)


def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
Expand Down Expand Up @@ -94,8 +94,12 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None:


def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
conditional_set_face_reference(temp_frame_paths)
frame_processors.process_video(source_path, temp_frame_paths, process_frames)


def conditional_set_face_reference(temp_frame_paths: List[str]) -> None:
if 'reference' in roop.globals.face_recognition and not get_face_reference():
reference_frame = cv2.imread(temp_frame_paths[roop.globals.reference_frame_number])
reference_face = get_one_face(reference_frame, roop.globals.reference_face_position)
set_face_reference(reference_face)
frame_processors.process_video(source_path, temp_frame_paths, process_frames)
24 changes: 12 additions & 12 deletions roop/processors/frame/__modules__/frame_enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from roop.typing import Frame, Face
from roop.utilities import conditional_download, resolve_relative_path

FRAME_ENHANCER = None
FRAME_PROCESSOR = None
THREAD_SEMAPHORE = threading.Semaphore()
THREAD_LOCK = threading.Lock()
NAME = 'ROOP.PROCESSORS.FRAME.FRAME_ENHANCER'
NAME = 'ROOP.FRAME_PROCESSOR.FRAME_ENHANCER'


def get_frame_enhancer() -> Any:
global FRAME_ENHANCER
def get_frame_processor() -> Any:
global FRAME_PROCESSOR

with THREAD_LOCK:
if FRAME_ENHANCER is None:
if FRAME_PROCESSOR is None:
model_path = resolve_relative_path('../models/RealESRGAN_x4plus.pth')
FRAME_ENHANCER = RealESRGANer(
FRAME_PROCESSOR = RealESRGANer(
model_path=model_path,
model=RRDBNet(
num_in_ch=3,
Expand All @@ -36,13 +36,13 @@ def get_frame_enhancer() -> Any:
pre_pad=0,
scale=4
)
return FRAME_ENHANCER
return FRAME_PROCESSOR


def clear_frame_enhancer() -> None:
global FRAME_ENHANCER
def clear_frame_processor() -> None:
global FRAME_PROCESSOR

FRAME_ENHANCER = None
FRAME_PROCESSOR = None


def pre_check() -> bool:
Expand All @@ -56,12 +56,12 @@ def pre_start() -> bool:


def post_process() -> None:
clear_frame_enhancer()
clear_frame_processor()


def enhance_frame(temp_frame: Frame) -> Frame:
with THREAD_SEMAPHORE:
temp_frame, _ = get_frame_enhancer().enhance(temp_frame, outscale=1)
temp_frame, _ = get_frame_processor().enhance(temp_frame, outscale=1)
return temp_frame


Expand Down
6 changes: 5 additions & 1 deletion roop/processors/frame/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

FRAME_PROCESSORS_MODULES: List[ModuleType] = []
FRAME_PROCESSORS_METHODS = [
'get_frame_processor',
'clear_frame_processor',
'pre_check',
'pre_start',
'process_frame',
Expand Down Expand Up @@ -49,6 +51,8 @@ def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType
def clear_frame_processors_modules() -> None:
global FRAME_PROCESSORS_MODULES

for frame_processor_module in get_frame_processors_modules(roop.globals.frame_processors):
frame_processor_module.clear_frame_processor()
FRAME_PROCESSORS_MODULES = []


Expand Down Expand Up @@ -108,4 +112,4 @@ def get_device() -> str:
return 'cuda'
if 'CoreMLExecutionProvider' in roop.globals.execution_providers:
return 'mps'
return 'cpu'
return 'cpu'
3 changes: 3 additions & 0 deletions roop/uis/__components__/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import onnxruntime

import roop.globals
from roop.face_analyser import clear_face_analyser
from roop.processors.frame.core import list_frame_processors_names, load_frame_processor_module, clear_frame_processors_modules
from roop.uis import core as ui
from roop.uis.typing import Update
Expand Down Expand Up @@ -93,6 +94,8 @@ def sort_frame_processors(frame_processors: List[str]) -> list[str]:


def update_execution_providers(execution_providers: List[str]) -> Update:
clear_face_analyser()
clear_frame_processors_modules()
roop.globals.execution_providers = execution_providers
return gradio.update(value=execution_providers)

Expand Down

0 comments on commit 270139b

Please sign in to comment.