Skip to content

Commit

Permalink
Use different structure
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Jun 13, 2023
1 parent 10132d2 commit 9734865
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 28 deletions.
10 changes: 4 additions & 6 deletions roop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import sys

from roop.frame_processors.core import get_frame_processor_modules

# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
Expand All @@ -26,8 +25,7 @@

import roop.globals
import roop.ui as ui
import roop.frame_processors.face_swapper
import roop.frame_processors.face_enhancer
from roop.processors.frame.core import get_frame_processor_module
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp
from roop.face_analyser import get_one_face

Expand Down Expand Up @@ -181,7 +179,7 @@ def start() -> None:
destroy()
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
module = get_frame_processor_module(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
release_resources()
if is_image(roop.globals.target_path):
Expand All @@ -200,7 +198,7 @@ def start() -> None:
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
module = get_frame_processor_module(frame_processor)
conditional_process_video(roop.globals.source_path, temp_frame_paths, module.process_video)
release_resources()
if roop.globals.keep_fps:
Expand Down Expand Up @@ -236,7 +234,7 @@ def run() -> None:
parse_args()
pre_check()
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module = get_frame_processor_module(frame_processor)
module.pre_check()
limit_resources()
if roop.globals.headless:
Expand Down
16 changes: 0 additions & 16 deletions roop/frame_processors/core.py

This file was deleted.

File renamed without changes.
Empty file.
16 changes: 16 additions & 0 deletions roop/processors/frame/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import sys
import importlib
from typing import Any

FRAME_PROCESSOR_MODULE = None


def get_frame_processor_module(frame_processor: str) -> Any:
global FRAME_PROCESSOR_MODULE

if not FRAME_PROCESSOR_MODULE:
try:
FRAME_PROCESSOR_MODULE = importlib.import_module(f'roop.processors.frame.{frame_processor}')
except ImportError:
sys.exit()
return FRAME_PROCESSOR_MODULE
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


def pre_check() -> None:
download_directory_path = resolve_relative_path('../../models')
download_directory_path = resolve_relative_path('../../../models')
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])


def get_code_former():
global CODE_FORMER
with THREAD_LOCK:
model_path = resolve_relative_path('../../models/codeformer.pth')
model_path = resolve_relative_path('../../../models/codeformer.pth')
if CODE_FORMER is None:
model = torch.load(model_path)['params_ema']
CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def pre_check() -> None:
download_directory_path = resolve_relative_path('../../models')
download_directory_path = resolve_relative_path('../../../models')
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])


Expand All @@ -22,7 +22,7 @@ def get_face_swapper() -> None:

with THREAD_LOCK:
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../../models/inswapper_128.onnx')
model_path = resolve_relative_path('../../../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER

Expand Down
4 changes: 2 additions & 2 deletions roop/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import roop.globals
from roop.face_analyser import get_one_face
from roop.capturer import get_video_frame, get_video_frame_total
from roop.frame_processors.core import get_frame_processor_modules
from roop.processors.frame.core import get_frame_processor_module
from roop.utilities import is_image, is_video, resolve_relative_path

WINDOW_HEIGHT = 700
Expand Down Expand Up @@ -205,7 +205,7 @@ def init_preview() -> None:
def update_preview(frame_number: int = 0) -> None:
if roop.globals.source_path and roop.globals.target_path:
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module = get_frame_processor_module(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
video_frame = module.process_faces(
get_one_face(cv2.imread(roop.globals.source_path)),
Expand Down

0 comments on commit 9734865

Please sign in to comment.