-
-
Notifications
You must be signed in to change notification settings - Fork 6.1k
/
core.py
70 lines (53 loc) · 1.99 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import importlib
import sys
from typing import Dict, Optional, List, Any
import cv2
import gradio
import roop.globals
import roop.metadata
from roop.typing import Frame
from roop.uis.typing import Component, ComponentName
from roop.utilities import list_module_names
# Module-wide registry of UI components keyed by name; written by
# register_component() and read by get_component().
COMPONENTS: Dict[ComponentName, Component] = dict()
# Names of the callables every UI layout module must provide; the layout
# loader verifies them and init() invokes them in this order.
UI_LAYOUT_METHODS = ['render', 'listen']
def init() -> None:
    """Build the Gradio app from the configured UI layouts and launch it.

    Each name in ``roop.globals.ui_layouts`` is loaded, rendered and wired up
    (in that order, one layout at a time) inside a single ``gradio.Blocks``
    context. The finished app is then launched with the API page hidden.
    """
    window_title = f'{roop.metadata.name} {roop.metadata.version}'
    with gradio.Blocks(theme=get_theme(), title=window_title) as ui:
        for layout_name in roop.globals.ui_layouts:
            layout = load_ui_layout_module(layout_name)
            # render() runs before listen(), presumably so event handlers can
            # reference the components the layout just created.
            for step in ('render', 'listen'):
                getattr(layout, step)()
    ui.launch(show_api=False)
def load_ui_layout_module(ui_layout: str) -> Any:
    """Import the UI layout module ``roop.uis.layouts.<ui_layout>`` and validate it.

    The module must expose a callable for every name in ``UI_LAYOUT_METHODS``.

    Args:
        ui_layout: Name of the layout module inside ``roop.uis.layouts``.

    Returns:
        The imported layout module.

    Raises:
        SystemExit: If the layout module (or a parent package of it) does not
            exist, or if it lacks one of the required callables.
        ModuleNotFoundError: Re-raised when the missing module is something the
            layout itself imports. The real missing dependency is reported
            instead of being blamed on the layout.
    """
    module_path = f'roop.uis.layouts.{ui_layout}'
    try:
        ui_layout_module = importlib.import_module(module_path)
    except ModuleNotFoundError as exception:
        missing_name = exception.name
        # Treat it as "layout not found" only when the module that failed is
        # the layout path itself or one of its parents (e.g. 'roop.uis.layouts').
        if (
            missing_name is None
            or module_path == missing_name
            or module_path.startswith(missing_name + '.')
        ):
            sys.exit(f'UI layout {ui_layout} could not be loaded.')
        # A dependency imported by the layout is missing; surface it unchanged.
        raise
    for method_name in UI_LAYOUT_METHODS:
        # hasattr() alone would accept non-callable attributes, which would
        # only fail later when init() calls them.
        if not callable(getattr(ui_layout_module, method_name, None)):
            sys.exit(f'UI layout {ui_layout} not implemented correctly.')
    return ui_layout_module
def get_theme() -> gradio.Theme:
    """Return the Gradio theme for the app.

    The theme is the Soft theme with a red primary hue, a gray secondary hue and
    the Inter font. It overrides the primary background fill and uses small
    text for block labels and titles.
    """
    themes = gradio.themes
    base_theme = themes.Soft(
        primary_hue=themes.colors.red,
        secondary_hue=themes.colors.gray,
        font=themes.GoogleFont('Inter'),
    )
    return base_theme.set(
        background_fill_primary='*neutral_50',
        block_label_text_size='*text_sm',
        block_title_text_size='*text_sm',
    )
def get_component(name: ComponentName) -> Optional[Component]:
    """Return the component registered under *name*, or ``None`` if none is registered."""
    return COMPONENTS.get(name)
def register_component(name: ComponentName, component: Component) -> None:
    """Record *component* in the module registry under *name*.

    Any component previously registered under the same name is replaced.
    """
    COMPONENTS.update({name: component})
def list_ui_layouts_names() -> Optional[List[str]]:
    """Return the names of the available UI layout modules.

    The directory listed here must match the package that
    ``load_ui_layout_module`` imports from (``roop.uis.layouts``). Otherwise the
    names offered to users could not actually be loaded. The previous path,
    ``roop/uis/__layouts__``, did not match that import path.

    Returns:
        Whatever ``list_module_names`` yields for ``roop/uis/layouts``; the
        annotation allows ``None``.
    """
    return list_module_names('roop/uis/layouts')
def normalize_frame(frame: Frame) -> Frame:
    """Return *frame* converted from BGR to RGB channel order via ``cv2.cvtColor``."""
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return rgb_frame